MatMul simplification, threading strategy improvements

Remove the MatMul f32 special case (smaller code).
types: add u32/u64 for use by Activations.
Move the renamed ParallelismStrategy to threading_context so that callers can pass a ThreadingContext.
Ensure the worker index is unique across clusters.
matmul.h: const member functions for the renamed policy classes (easier to call).
PiperOrigin-RevId: 802848086
Author: Jan Wassenberg, 2025-09-03 21:44:39 -07:00; committed by Copybara-Service
Commit 7263ab8445 (parent 74ffe079c4)
13 changed files with 514 additions and 478 deletions
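Illustrative only, not part of the diff: with ParallelismStrategy now living in threading_context, callers hand the whole ThreadingContext to ParallelFor, and each cluster offsets its local thread index by cluster_idx * ctx.pools.MaxWorkersPerCluster(), which is what makes the worker argument unique across clusters. A minimal sketch of the new calling convention, assuming a ThreadingContext ctx and a task count num_tasks:

  // Flat parallelism on one cluster; `worker` is already cluster-offset,
  // so per-worker scratch indexed by it never collides across clusters.
  ParallelFor(ParallelismStrategy::kFlat, num_tasks, ctx, /*cluster_idx=*/0,
              [&](size_t task, size_t worker) { /* ... */ });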


@@ -191,12 +191,13 @@ constexpr bool SupportsPointerArithmetic() {
   return !IsNuqStream<Packed>();
 }
 
-// Tensor types for loading weights.
-enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64 };
+// Tensor types for loading weights. Not all of these are supported weight
+// types, some are only used for `Activations`.
+enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kU32, kU64 };
 // These are used in `ModelConfig.Specifier`, hence the strings will not
 // change, though new ones may be added.
-static constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16",
-                                               "sfp", "nuq", "f64"};
+static constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp",
+                                               "nuq", "f64", "u32", "u64"};
 static constexpr size_t kNumTypes =
     sizeof(kTypeStrings) / sizeof(kTypeStrings[0]);
 static constexpr size_t kTypeBits[] = {
@@ -206,6 +207,8 @@ static constexpr size_t kTypeBits[] = {
     8 * sizeof(SfpStream),
     4 /* NuqStream, actually 4.5 */,
     8 * sizeof(double),
+    8 * sizeof(uint32_t),
+    8 * sizeof(uint64_t),
 };
 
 static inline bool EnumValid(Type type) {
@@ -226,6 +229,10 @@ Type TypeEnum() {
     return Type::kNUQ;
   } else if constexpr (hwy::IsSame<Packed, double>()) {
     return Type::kF64;
+  } else if constexpr (hwy::IsSame<Packed, uint32_t>()) {
+    return Type::kU32;
+  } else if constexpr (hwy::IsSame<Packed, uint64_t>()) {
+    return Type::kU64;
   } else {
     HWY_DASSERT(false);
     return Type::kUnknown;
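For illustration only (not in the diff), and assuming the enum value indexes kTypeStrings/kTypeBits as the ordering suggests, the new entries behave like the existing ones:

  const Type t = TypeEnum<uint32_t>();                      // Type::kU32
  const char* name = kTypeStrings[static_cast<size_t>(t)];  // "u32"
  const size_t bits = kTypeBits[static_cast<size_t>(t)];    // 8 * sizeof(uint32_t)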


@@ -21,14 +21,14 @@
 #include <stdint.h>
 
 #include <atomic>
-#include <memory>
 #include <vector>
 
 #include "gemma/configs.h"  // ModelConfig
+#include "ops/matmul.h"     // MatMulEnv
 #include "ops/ops.h"        // CreateInvTimescale
-#include "util/allocator.h"  // Allocator
 #include "util/basics.h"    // BF16
 #include "util/mat.h"       // MatStorageT
+#include "util/threading_context.h"
 
 namespace gcpp {
@@ -150,24 +150,28 @@ struct AttentionActivations {
 struct Activations {
   Activations(const ModelConfig& config, size_t batch_size, size_t seq_len,
-              const Allocator& allocator,
+              ThreadingContext& ctx,
               std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
       : layer_config(config.layer_configs[0]),
-        x(MatFactory("x", batch_size, config.model_dim, allocator)),
-        x_bf(MatFactory("x_bf", batch_size, config.model_dim, allocator)),
-        logits(MatFactory("logits", batch_size, config.vocab_size, allocator)),
+        x(MatFactory("x", batch_size, config.model_dim, ctx.allocator)),
+        x_bf(MatFactory("x_bf", batch_size, config.model_dim, ctx.allocator)),
+        logits(
+            MatFactory("logits", batch_size, config.vocab_size, ctx.allocator)),
         pre_ffw_rms_out(MatFactory("pre_ffw_rms_out", batch_size,
-                                   config.model_dim, allocator)),
-        C1(MatFactory("C1", batch_size, layer_config.ff_hidden_dim, allocator)),
-        C2(MatFactory("C2", batch_size, layer_config.ff_hidden_dim, allocator)),
-        ffw_out(MatFactory("ffw_out", batch_size, config.model_dim, allocator)),
-        attention(config, layer_config, batch_size, seq_len, allocator,
+                                   config.model_dim, ctx.allocator)),
+        C1(MatFactory("C1", batch_size, layer_config.ff_hidden_dim,
+                      ctx.allocator)),
+        C2(MatFactory("C2", batch_size, layer_config.ff_hidden_dim,
+                      ctx.allocator)),
+        ffw_out(
+            MatFactory("ffw_out", batch_size, config.model_dim, ctx.allocator)),
+        attention(config, layer_config, batch_size, seq_len, ctx.allocator,
                   row_ptrs),
         griffin(config, config.model == Model::GRIFFIN_2B ? batch_size : 0,
-                allocator) {
+                ctx.allocator) {
     HWY_ASSERT(batch_size != 0);
 
     // For MatMul outputs, precompute their row pointers.


@@ -19,7 +19,6 @@
 #include <vector>
 
 #include "compression/types.h"  // GEMMA_DISABLED_TARGETS
-#include "util/threading_context.h"
 
 #ifndef HWY_DISABLED_TARGETS
 #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
 #endif  // HWY_DISABLED_TARGETS
@@ -29,7 +28,7 @@
 #include "gemma/gemma.h"
 #include "gemma/weights.h"
 #include "util/threading.h"
-#include "hwy/contrib/thread_pool/thread_pool.h"
+#include "util/threading_context.h"
 #include "hwy/profiler.h"
 
 // Compiles this file for multiple architectures via "foreach_target.h", to
@@ -234,8 +233,9 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
   {
     PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin");
     // Full parallelism is helpful, kAcrossClusters is insufficient.
-    NestedParallelFor(num_tokens * div_qbatch.GetDivisor() * layer_config.heads,
-                      ctx.pools, func);
+    HierarchicalParallelFor(
+        num_tokens * div_qbatch.GetDivisor() * layer_config.heads, ctx.pools,
+        func);
   }
 }
@@ -285,9 +285,9 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
   // Apply positional encodings for K.
   // Note that 2D parallelism is not worth the fork/join overhead because the
   // tasks are very lightweight.
-  env.ctx.pools.Pool(0).Run(
-      0, kv_heads * num_interleaved,
-      [&](uint64_t task, size_t thread) HWY_ATTR {
+  ParallelFor(
+      ParallelismStrategy::kFlat, kv_heads * num_interleaved, env.ctx,
+      /*cluster_idx=*/0, [&](size_t task, size_t worker) HWY_ATTR {
         const size_t head = task % kv_heads;
         const size_t interleaved_idx = task / kv_heads;
         const size_t qi = div_qbatch.Remainder(interleaved_idx);
@@ -308,12 +308,12 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
        if (layer.key_norm_scale.HasPtr()) {
          CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) {
            RMSNormInplace(weights_t->PackedScale1(), kv_f32, qkv_dim,
-                           env.ctx.profiler, thread);
+                           env.ctx.profiler, worker);
          });
        }
 
        PositionalEncodingQK(kv_f32, layer_idx, layer, activations,
-                             env.ctx.profiler, thread, pos);
+                             env.ctx.profiler, worker, pos);
        CompressPerThread tls;
        Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
      });


@@ -69,9 +69,9 @@ template <class Mat>
 void ActivationBatched(
     ActivationType activation, Mat& c1, ThreadingContext& ctx,
     size_t cluster_idx = 0,
-    ParallelismType parallelism = ParallelismType::kAcrossClusters) {
+    ParallelismStrategy parallelism = ParallelismStrategy::kFlat) {
   using T = typename Mat::T;
-  ParallelFor(parallelism, c1.Rows(), ctx.pools, cluster_idx,
+  ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx,
               [&](uint64_t task, size_t worker) {
                 // Cast to correct type so type deduction works.
                 Activation(activation, c1.Row(task),
@@ -84,16 +84,16 @@ template <class Mat1, class Mat2>
 HWY_NOINLINE void ActivationBatched(
     ActivationType activation, Mat1& c1, const Mat2* c2, ThreadingContext& ctx,
     size_t cluster_idx = 0,
-    ParallelismType parallelism = ParallelismType::kAcrossClusters) {
+    ParallelismStrategy parallelism = ParallelismStrategy::kFlat) {
   HWY_DASSERT(c1.SameShape(*c2));
   if (c2 && c2->HasPtr()) {
-    ParallelFor(parallelism, c1.Rows(), ctx.pools, cluster_idx,
+    ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx,
                 [&](uint64_t task, size_t worker) {
                   Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(),
                              ctx.profiler, worker);
                 });
   } else {  // No multiplier
-    ParallelFor(parallelism, c1.Rows(), ctx.pools, cluster_idx,
+    ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx,
                 [&](uint64_t task, size_t worker) {
                   Activation(activation, c1.Row(task),
                              static_cast<const typename Mat2::T*>(nullptr),
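Hypothetical call sites for the updated signatures above (not in the diff); c1/c2 are activation matrices and ctx a ThreadingContext:

  ActivationBatched(ActivationType::kGelu, c1, ctx);  // defaults: kFlat, cluster 0
  ActivationBatched(ActivationType::kGelu, c1, &c2, ctx, /*cluster_idx=*/0,
                    ParallelismStrategy::kHierarchical);  // spread across clusters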


@@ -574,7 +574,7 @@ void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
                      const WeightsPtrs& weights, KVCache& kv_cache,
                      MatMulEnv& env, TimingInfo& timing_info) {
   Activations activations(config, runtime_config.prefill_tbatch_size,
-                          kv_cache.SeqLen(), env.ctx.allocator, env.row_ptrs);
+                          kv_cache.SeqLen(), env.ctx, env.row_ptrs);
 
   AllQueries all_queries(prompt, pos, prefix_end,
                          hwy::Span<KVCache>(&kv_cache, 1));
@@ -592,7 +592,7 @@ void GenerateBatchT(const ModelConfig& config,
   const size_t max_batch_size = HWY_MAX(runtime_config.decode_qbatch_size,
                                         runtime_config.prefill_tbatch_size);
   Activations activations(config, max_batch_size,
-                          all_queries[0].kv_cache.SeqLen(), env.ctx.allocator,
+                          all_queries[0].kv_cache.SeqLen(), env.ctx,
                           env.row_ptrs);
 
   for (size_t start = 0; start < all_queries.NumQueries();
@@ -616,8 +616,8 @@ void GenerateImageTokensT(const ModelConfig& config,
   const size_t num_tokens = vit_config.max_seq_len;
   prefill_runtime_config.prefill_tbatch_size =
       num_tokens / (vit_config.pool_dim * vit_config.pool_dim);
-  Activations prefill_activations(vit_config, num_tokens, num_tokens,
-                                  env.ctx.allocator, env.row_ptrs);
+  Activations prefill_activations(vit_config, num_tokens, num_tokens, env.ctx,
+                                  env.row_ptrs);
   // Weights are for the full PaliGemma model, not just the ViT part.
   PrefillVit(config, weights, prefill_runtime_config, image, image_tokens,
              prefill_activations, env);


@@ -111,8 +111,8 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
   // Ensure usage conditions are set before autotuning. Both binding and
   // spinning may materially affect the choice of config. No harm in calling
   // BindB/C if there is a single package: they will be a no-op.
-  BindB(b_trans, sizeof(TC), env.parallel);
-  BindC(C, env.parallel);
+  BindB(env.ctx, b_trans, sizeof(TC));
+  BindC(env.ctx, C);
   C.AllocateAndAttachRowPtrs(env.row_ptrs);
 
   Tristate use_spinning = Tristate::kDefault;
@@ -160,10 +160,10 @@ void BenchAllMatMul() {
           ctx.pools.PinString());
   MatMulEnv env(ctx);
 
-  for (size_t batch_size : {1, 4, 128, 512}) {
+  for (size_t batch_size : {128, 512}) {
     constexpr bool kAdd = false;
-    BenchMatMul<BF16, SFP, BF16>(batch_size, 24576, 3072, kAdd, env);
-    BenchMatMul<BF16, SFP, BF16>(batch_size, 3072, 24576, kAdd, env);
+    BenchMatMul<BF16, BF16, BF16>(batch_size, 24576, 3072, kAdd, env);
+    BenchMatMul<BF16, BF16, BF16>(batch_size, 3072, 24576, kAdd, env);
   }
 
   PROFILER_PRINT_RESULTS();


@@ -565,46 +565,204 @@ class MMKernel {
   }
 };
 
-// Called on the main thread with the entire N range, or by each package with
-// a static partition of N. This class contains several variants of the
-// outer M/N/K loops, and calls `A2C0` which loops over the inner KC and MC.
-// Its member variables avoid long argument lists in Do*().
-class MMPerPackage {
- public:
-  MMPerPackage(const Extents2D A, const MMArgs& args, const MMConfig& config,
-               size_t pkg_idx, size_t cluster_idx, const IndexRange& range_np)
-      : args_(args),
-        pkg_idx_(pkg_idx),
-        cluster_idx_(cluster_idx),
-        range_np_(range_np),
-        mr_(config.MR()),
-        ranges_mc_(config.RangesOfMC(A.rows)),
-        ranges_kc_(config.RangesOfKC(A.cols)),
-        ranges_nc_(config.RangesOfNC(range_np)),
-        order_(config.Order()),
-        inner_tasks_(config.InnerTasks()),
-        line_bytes_(args.env->ctx.allocator.LineBytes()) {}
-
-  // B and maybe A are decompressed several call layers lower, but not all
-  // member functions depend on TA/TB, so pass them as an argument instead of
-  // templating the class.
-  template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
-  HWY_NOINLINE void operator()(const MMParallelPolicyT& parallel_policy,
-                               const MatPtrT<TA>& A, const MatPtrT<TB>& B,
-                               RowPtrs<TC> C_rows) const {
+// Miscellaneous stateless helper functions.
+struct MMImpl {
+  // Returns existing entry for the given key or -1.
+  static HWY_INLINE intptr_t IndexOfKey(MMKeys::Key key, const MMKeys& keys) {
+    const hwy::Span<const uint64_t> all_keys = keys.Keys();
+    // TODO: SIMD scan
+    for (size_t i = 0; i < all_keys.size(); ++i) {
+      if (all_keys[i] == key) return static_cast<intptr_t>(i);
+    }
+    return -1;
+  }
+
+  static size_t Worker(const MMArgs& args) {
+    return args.options.cluster_idx *
+           args.env->ctx.pools.MaxWorkersPerCluster();
+  }
+
+  template <class Func>
+  static void DispatchParallelism(ParallelismStrategy parallelism,
+                                  const Func& func) {
+    switch (parallelism) {
+      case ParallelismStrategy::kHierarchical:
+        return func(MMParallelHierarchical());
+      case ParallelismStrategy::kNone:
+        return func(MMParallelNone());
+      case ParallelismStrategy::kWithinCluster:
+        return func(MMParallelWithinCluster());
+      default:
+        HWY_UNREACHABLE;
+    }
+  }
+
+  // Decompresses all `M x K` from `A` into padded BF16 `A_view`.
+  static HWY_NOINLINE void DoDecompressA(const MatPtrT<float>& A,
+                                         const StridedViewBF A_view,
+                                         MMParA par_a, const MMArgs& args) {
+    const IndexRange all_M(0, A.Rows());
+    const IndexRange all_K(0, A.Cols());
+    HWY_DASSERT(all_K.Num() == A_view.Cols());
+    const hn::ScalableTag<BF16> dbf;
+    const size_t NBF = hn::Lanes(dbf);
+    static const auto zone = args.env->ctx.profiler.AddZone("MM.DecompressA");
+
+    const auto do_range =
+        [&](const IndexRange& range_M, const IndexRange& range_K, size_t worker)
+            HWY_ATTR {
+              MMZone mm_zone;
+              mm_zone.MaybeEnter(worker, zone, args);
+              const size_t col0 = range_K.begin();
+              const size_t cols = range_K.Num();
+              // Must be a vector multiple, or the last range before row
+              // padding, otherwise `DecompressAndZeroPad` overwrites neighbors.
+              HWY_DASSERT(cols % NBF == 0 || range_K.end() == A.Cols());
+              for (size_t row_a : range_M) {
+                const PackedSpan<const float> from =
+                    MakeSpan(A.Row(row_a) + col0, cols);
+                BF16* HWY_RESTRICT to = A_view.Row(row_a) + col0;
+                DecompressAndZeroPad(dbf, from, 0, to, cols);
+                // Verify that we zero-padded.
+                if constexpr (HWY_IS_DEBUG_BUILD) {
+                  for (size_t i = cols; i < hwy::RoundUpTo(cols, NBF); ++i) {
+                    HWY_DASSERT(hwy::ConvertScalarTo<float>(to[i]) == 0.0f);
+                  }
+                }
+              }
+            };
+
+    switch (par_a) {
+      case MMParA::kNone:
+        do_range(all_M, all_K, MMImpl::Worker(args));
+        break;
+      case MMParA::kK1:
+      case MMParA::kK2:
+      case MMParA::kK4: {
+        const size_t inner_tasks = static_cast<size_t>(par_a);
+        // At least one vector, otherwise DecompressAndZeroPad will add
+        // padding, which might overwrite neighboring tasks. Also a whole cache
+        // line to avoid false sharing.
+        const size_t multiple_K = HWY_MAX(NBF, args.line_bytes / sizeof(BF16));
+        DispatchParallelism(
+            args.options.parallelism, [&](const auto& parallel) {
+              parallel.ForNP(args.env->ctx, all_K, multiple_K, inner_tasks,
+                             args.options.cluster_idx,
+                             [&](const IndexRange& range_K, size_t worker) {
+                               do_range(all_M, range_K, worker);
+                             });
+            });
+        break;
+      }
+      case MMParA::kM:
+        DispatchParallelism(
+            args.options.parallelism, [&](const auto& parallel) {
+              parallel.ForRangeMC(
+                  args.env->ctx, all_M, args.options.cluster_idx,
+                  [&](size_t row_a, size_t worker) {
+                    do_range(IndexRange(row_a, row_a + 1), all_K, worker);
+                  });
+            });
+        break;
+    }
+  }
+
+  // Autotuning wrapper for `DoDecompressA`.
+  static HWY_INLINE void DecompressA(const MatPtrT<float>& A,
+                                     const StridedViewBF A_view,
+                                     const MMArgs& args) {
+    MMAutoTune<MMParA>& autotune = args.per_key->autotune_par_a[/*pkg_idx=*/0];
+    if (HWY_LIKELY(autotune.Best())) {
+      return DoDecompressA(A, A_view, *autotune.Best(), args);
+    }
+
+    // First call: generate candidates.
+    if (HWY_UNLIKELY(!autotune.HasCandidates())) {
+      const MMParA other = (A.Rows() == 1) ? MMParA::kNone : MMParA::kM;
+      std::vector<MMParA> candidates = {MMParA::kK1, MMParA::kK2, MMParA::kK4,
+                                        other};
+      autotune.SetCandidates(candidates);
+    }
+
+    const MMParA& par_a = autotune.NextConfig();
+    const uint64_t t0 = hwy::timer::Start();
+    DoDecompressA(A, A_view, par_a, args);
+    const uint64_t t1 =
+        args.env->have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
+    const uint64_t min_elapsed = autotune.NotifyTicks(t1 - t0);
+    if (HWY_UNLIKELY(args.env->print_measurement && autotune.ShouldPrint())) {
+      fprintf(stderr, "%s,%7.3f\n", StringFromParA(par_a),
+              static_cast<double>(min_elapsed) /
+                  hwy::platform::InvariantTicksPerSecond() * 1E6);
+    }
+  }
+
+  // Returns 2D subrange whose top-left is `r, c` and width is `cols`.
+  template <typename T>
+  static StridedView<T> View(const MatPtrT<T>& AB, size_t r, size_t c,
+                             size_t cols) {
+    HWY_DASSERT(c < AB.Cols());
+    HWY_DASSERT(cols <= AB.Cols() - c);
+    return StridedView<T>(const_cast<T*>(AB.Row(r)) + c, cols, AB.Stride());
+  }
+
+  template <typename TA>
+  static HWY_INLINE StridedViewBF MaybeDecompressA(const MatPtrT<TA>& A,
+                                                   const MMArgs& args) {
     if constexpr (IsBF16<TA>()) {
       // We can use a view, regardless of columns/padding, because `LoopKC`
       // supports non-vector multiples.
-      DispatchOrder(parallel_policy, View(A, 0, 0, A.Cols()), B, C_rows);
+      return View(A, 0, 0, A.Cols());
     } else {
       // Always decompress. To reduce code size/compile time, we no longer
       // support a separate F32 kernel; most A are already BF16.
       const StridedViewBF A_view =
-          args_.env->storage[cluster_idx_].A(pkg_idx_, A.Extents());
-      DecompressA<MMParallelPolicyT>(A, A_view);
-      DispatchOrder(parallel_policy, A_view, B, C_rows);
+          args.env->storage[args.options.cluster_idx].A(/*pkg_idx=*/0,
+                                                        A.Extents());
+      DecompressA(A, A_view, args);
+      return A_view;
     }
   }
+};
+
+// Contains several variants of the outer M/N/K loops, and calls `A2C0` which
+// loops over the inner KC and MC. Member variables avoid long argument lists.
+class MMState {
+ public:
+  MMState(const Extents2D A, const MMArgs& args, const MMConfig& config)
+      : args_(args),
+        range_np_(args.per_key->ranges_np.Range(/*pkg_idx=*/0)),
+        mr_(config.MR()),
+        ranges_mc_(config.RangesOfMC(A.rows)),
+        ranges_kc_(config.RangesOfKC(A.cols)),
+        ranges_nc_(config.RangesOfNC(range_np_)),
+        order_(config.Order()),
+        inner_tasks_(config.InnerTasks()) {
+    HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1);
+  }
+
+  // Called from `MatMul` from two places: either with the next autotune
+  // config, or with the best config.
+  template <typename TB, typename TC>
+  HWY_NOINLINE void DispatchParallelism(const StridedViewBF A,
+                                        const MatPtrT<TB>& B,
+                                        RowPtrs<TC> C_rows) const {
+    /* Disabled due to unknown thread-safety issue:
+    static const auto zone =
+        args_.env->ctx.profiler.AddZone("MM.DispatchParallelism");
+    PROFILER_ZONE3(args_.env->ctx.profiler, MMImpl::Worker(args_), zone);
+    */
+    MMImpl::DispatchParallelism(
+        args_.options.parallelism,
+        [&](const auto& parallel) { DispatchOrder(parallel, A, B, C_rows); });
+  }
+
  private:
   // Compute size of per-worker storage for `kNR` row ranges of B. Stack
@@ -616,40 +774,32 @@
   // Granularity of `ForNP`. B rows produce C columns, so we
   // want a multiple of the line size to prevent false sharing.
   size_t MultipleNP(size_t sizeof_TC) const {
-    return HWY_MAX(kNR, line_bytes_ / sizeof_TC);
+    return HWY_MAX(kNR, args_.line_bytes / sizeof_TC);
   }
 
-  // Returns 2D subrange whose top-left is `r, c` and width is `cols`.
-  template <typename T>
-  static StridedView<T> View(const MatPtrT<T>& AB, size_t r, size_t c,
-                             size_t cols) {
-    HWY_DASSERT(c < AB.Cols());
-    HWY_DASSERT(cols <= AB.Cols() - c);
-    return StridedView<T>(const_cast<T*>(AB.Row(r)) + c, cols, AB.Stride());
-  }
-
-  // `TA` is usually BF16, but can be F32 if `!HWY_NATIVE_DOT_BF16`.
-  template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
-  HWY_INLINE void DispatchOrder(const MMParallelPolicyT& parallel_policy,
-                                const StridedView<TA> A, const MatPtrT<TB>& B,
+  // B is decompressed several call layers lower, but not all member functions
+  // depend on `TB`, so pass it as an argument instead of templating the class.
+  template <typename TB, typename TC, class ParallelT>
+  HWY_NOINLINE void DispatchOrder(const ParallelT& parallel_policy,
+                                  const StridedViewBF A, const MatPtrT<TB>& B,
                                 RowPtrs<TC> C_rows) const {
     switch (order_) {
       case MMOrder::kNT:
-        return DoNT<TA, TB, TC>(parallel_policy, A, B, C_rows);
+        return DoNT(parallel_policy, A, B, C_rows);
       case MMOrder::kNT_K:
-        return DoNT_K<TA, TB, TC>(parallel_policy, A, B, C_rows);
+        return DoNT_K(parallel_policy, A, B, C_rows);
       case MMOrder::kNT_MT:
-        return DoNT_MT<TA, TB, TC>(parallel_policy, A, B, C_rows);
+        return DoNT_MT(parallel_policy, A, B, C_rows);
       case MMOrder::kNT_MT_K:
-        return DoNT_MT_K<TA, TB, TC>(parallel_policy, A, B, C_rows);
+        return DoNT_MT_K(parallel_policy, A, B, C_rows);
       default:
        HWY_UNREACHABLE;
    }
  }
 
   // Single M and K ranges, parallel N. Fills all of C directly.
-  template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
-  HWY_INLINE void DoNT(MMParallelPolicyT, const StridedView<TA> A,
+  template <typename TB, typename TC, class ParallelT>
+  HWY_INLINE void DoNT(ParallelT parallel, const StridedViewBF A,
                        const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
     static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT");
     HWY_DASSERT(ranges_mc_.NumTasks() == 1);
@@ -657,14 +807,14 @@
     const IndexRange& range_M = ranges_mc_.Range(0);
     const IndexRange& range_K = ranges_kc_.Range(0);
     const size_t K = range_K.Num();
-    const StridedView<TA> A_view = A.View(range_M.begin(), 0, K);
+    const StridedViewBF A_view = A.View(range_M.begin(), 0, K);
     const size_t B_stride =
-        Stride(MatPadding::kOdd, K, sizeof(BF16), line_bytes_);
+        Stride(MatPadding::kOdd, K, sizeof(BF16), args_.line_bytes);
 
     // Similar to `loop_nc` below, but here we hoisted `A_view`.
-    MMParallelPolicyT::ForNP(
+    parallel.ForNP(
         args_.env->ctx, range_np_, MultipleNP(sizeof(TC)), inner_tasks_,
-        pkg_idx_, cluster_idx_,
+        args_.options.cluster_idx,
         [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
           MMZone mm_zone;
           mm_zone.MaybeEnter(worker, zone, args_);
@@ -683,8 +833,8 @@
   }
 
   // Single M range, parallel N, sequential K. Sets C, then accumulates.
-  template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
-  HWY_INLINE void DoNT_K(MMParallelPolicyT, const StridedView<TA> A,
+  template <typename TB, typename TC, class ParallelT>
+  HWY_INLINE void DoNT_K(ParallelT parallel, const StridedViewBF A,
                          const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
     static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_K");
     HWY_DASSERT(ranges_mc_.NumTasks() == 1);
@@ -697,11 +847,11 @@
                              const IndexRange& range_nc,
                              auto out_tag) HWY_ATTR {
       const size_t kc = range_kc.Num();
-      const StridedView<TA> A_view =
+      const StridedViewBF A_view =
           A.View(range_mc.begin(), range_kc.begin(), kc);
       const StridedViewBF B_storage_view(
           B_storage, kc,
-          Stride(MatPadding::kOdd, kc, sizeof(BF16), line_bytes_));
+          Stride(MatPadding::kOdd, kc, sizeof(BF16), args_.line_bytes));
 
       for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
            row_b += kNR) {
@@ -711,9 +861,9 @@
       }
     };
 
-    MMParallelPolicyT::ForNP(
+    parallel.ForNP(
         args_.env->ctx, range_np_, MultipleNP(sizeof(TC)), inner_tasks_,
-        pkg_idx_, cluster_idx_,
+        args_.options.cluster_idx,
         [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
           MMZone mm_zone;
           mm_zone.MaybeEnter(worker, zone, args_);
@@ -733,26 +883,26 @@
   // Parallel loops over mc/nc blocks of M/range_np, single K.
   // Fills `mc x nc` sections of C directly, in parallel.
-  template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
-  HWY_INLINE void DoNT_MT(MMParallelPolicyT, const StridedView<TA> A,
+  template <typename TB, typename TC, class ParallelT>
+  HWY_INLINE void DoNT_MT(ParallelT parallel, const StridedViewBF A,
                           const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
     static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT");
     HWY_DASSERT(ranges_kc_.NumTasks() == 1);
     const IndexRange& range_K = ranges_kc_.Range(0);
     const size_t K = range_K.Num();
     const size_t B_stride =
-        Stride(MatPadding::kOdd, K, sizeof(BF16), line_bytes_);
+        Stride(MatPadding::kOdd, K, sizeof(BF16), args_.line_bytes);
 
     // Sequential loop over NC/MC/KC, similar to `loop_nc` below
     // except for the profiler strings and `out_tag`.
-    MMParallelPolicyT::ForRangesMC_NC(
-        args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_, cluster_idx_,
+    parallel.ForRangesMC_NC(
+        args_.env->ctx, ranges_mc_, ranges_nc_, args_.options.cluster_idx,
         [&](const IndexRange& range_mc, const IndexRange& range_nc,
             size_t worker) HWY_ATTR {
           MMZone mm_zone;
           mm_zone.MaybeEnter(worker, zone, args_);
-          const StridedView<TA> A_view = A.View(range_mc.begin(), 0, K);
+          const StridedViewBF A_view = A.View(range_mc.begin(), 0, K);
 
           HWY_ALIGN BF16 B_storage[B_storage_max_];  // TLS
           const StridedViewBF B_storage_view(B_storage, K, B_stride);
@@ -768,14 +918,14 @@
   // Parallel loops over mc/nc blocks of M/range_np, sequential K.
   // Fills `mc x nc` sections of `partial`, then `C`, in parallel.
-  template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
-  HWY_INLINE void DoNT_MT_K(MMParallelPolicyT, const StridedView<TA> A,
+  template <typename TB, typename TC, class ParallelT>
+  HWY_INLINE void DoNT_MT_K(ParallelT parallel, const StridedViewBF A,
                             const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
     static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT_K");
     const size_t kc_max = ranges_kc_.TaskSize();
     HWY_DASSERT(kc_max <= MMStorage::kMaxKC);
     const size_t B_stride =
-        Stride(MatPadding::kOdd, kc_max, sizeof(BF16), line_bytes_);
+        Stride(MatPadding::kOdd, kc_max, sizeof(BF16), args_.line_bytes);
 
     // Sequential loop over NC/MC/KC, for when the M/N loops are
     // already parallel. This is B3A2C0 in MOMMS terminology: we read
     // `mc x kc` of A, `nc x kc` of B, update `mc x nc` of `partial`.
@@ -785,7 +935,7 @@
                              const IndexRange& range_nc,
                              auto out_tag) HWY_ATTR {
       const size_t kc = range_kc.Num();
-      const StridedView<TA> A_view =
+      const StridedViewBF A_view =
           A.View(range_mc.begin(), range_kc.begin(), kc);
 
       for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
@@ -795,8 +945,8 @@
                  C_rows);
       }
     };  // loop_nc
-    MMParallelPolicyT::ForRangesMC_NC(
-        args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_, cluster_idx_,
+    parallel.ForRangesMC_NC(
+        args_.env->ctx, ranges_mc_, ranges_nc_, args_.options.cluster_idx,
         [&](const IndexRange& range_mc, const IndexRange& range_nc,
             size_t worker) HWY_ATTR {
           MMZone mm_zone;
@@ -816,106 +966,6 @@
         });
   }
 
-  // Decompresses all `M x K` from `A` into padded BF16 `A_view`.
-  template <typename MMParallelPolicyT>
-  HWY_NOINLINE void DoDecompressA(const MatPtrT<float>& A,
-                                  const StridedViewBF A_view,
-                                  MMParA par_a) const {
-    const IndexRange all_M(0, A.Rows());
-    const IndexRange all_K(0, A.Cols());
-    HWY_DASSERT(all_K.Num() == A_view.Cols());
-    const hn::ScalableTag<BF16> dbf;
-    const size_t NBF = hn::Lanes(dbf);
-    static const auto zone = args_.env->ctx.profiler.AddZone("MM.DecompressA");
-
-    const auto do_range = [&](const IndexRange& range_M,
-                              const IndexRange& range_K,
-                              size_t worker) HWY_ATTR {
-      MMZone mm_zone;
-      mm_zone.MaybeEnter(worker, zone, args_);
-      const size_t col0 = range_K.begin();
-      const size_t cols = range_K.Num();
-      // Must be a vector multiple, or the last range before row padding,
-      // otherwise `DecompressAndZeroPad` overwrites neighbors.
-      HWY_DASSERT(cols % NBF == 0 || range_K.end() == A.Cols());
-      for (size_t row_a : range_M) {
-        const PackedSpan<const float> from =
-            MakeSpan(A.Row(row_a) + col0, cols);
-        BF16* HWY_RESTRICT to = A_view.Row(row_a) + col0;
-        DecompressAndZeroPad(dbf, from, 0, to, cols);
-        // Verify that we zero-padded.
-        if constexpr (HWY_IS_DEBUG_BUILD) {
-          for (size_t i = cols; i < hwy::RoundUpTo(cols, NBF); ++i) {
-            HWY_DASSERT(hwy::ConvertScalarTo<float>(to[i]) == 0.0f);
-          }
-        }
-      }
-    };
-
-    switch (par_a) {
-      case MMParA::kNone:
-        do_range(all_M, all_K, /*worker=*/0);
-        break;
-      case MMParA::kK1:
-      case MMParA::kK2:
-      case MMParA::kK4: {
-        const size_t inner_tasks = static_cast<size_t>(par_a);
-        // At least one vector, otherwise DecompressAndZeroPad will add
-        // padding, which might overwrite neighboring tasks. Also a whole cache
-        // line to avoid false sharing.
-        const size_t multiple_K = HWY_MAX(NBF, line_bytes_ / sizeof(BF16));
-        MMParallelPolicyT::ForNP(args_.env->ctx, all_K, multiple_K, inner_tasks,
-                                 pkg_idx_, cluster_idx_,
-                                 [&](const IndexRange& range_K, size_t worker) {
-                                   do_range(all_M, range_K, worker);
-                                 });
-        break;
-      }
-      case MMParA::kM:
-        MMParallelPolicyT::ForRangeMC(
-            args_.env->ctx, all_M, pkg_idx_, cluster_idx_,
-            [&](size_t row_a, size_t worker) {
-              do_range(IndexRange(row_a, row_a + 1), all_K, worker);
-            });
-        break;
-    }
-  }
-
-  // Autotuning wrapper for `DoDecompressA`.
-  template <typename MMParallelPolicyT>
-  HWY_INLINE void DecompressA(const MatPtrT<float>& A,
-                              const StridedViewBF A_view) const {
-    MMAutoTune<MMParA>& autotune = args_.per_key->autotune_par_a[pkg_idx_];
-    if (HWY_LIKELY(autotune.Best())) {
-      return DoDecompressA<MMParallelPolicyT>(A, A_view, *autotune.Best());
-    }
-
-    // First call: generate candidates.
-    if (HWY_UNLIKELY(!autotune.HasCandidates())) {
-      const MMParA other = (A.Rows() == 1) ? MMParA::kNone : MMParA::kM;
-      std::vector<MMParA> candidates = {MMParA::kK1, MMParA::kK2, MMParA::kK4,
-                                        other};
-      autotune.SetCandidates(candidates);
-    }
-
-    const MMParA& par_a = autotune.NextConfig();
-    const uint64_t t0 = hwy::timer::Start();
-    DoDecompressA<MMParallelPolicyT>(A, A_view, par_a);
-    const uint64_t t1 =
-        args_.env->have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
-    const uint64_t min_elapsed = autotune.NotifyTicks(t1 - t0);
-    if (HWY_UNLIKELY(args_.env->print_measurement && autotune.ShouldPrint())) {
-      fprintf(stderr, "%s,%7.3f\n", StringFromParA(par_a),
-              static_cast<double>(min_elapsed) /
-                  hwy::platform::InvariantTicksPerSecond() * 1E6);
-    }
-  }
-
   // Decompresses `kNR x kc` from `B[row_b, range_kc.begin()]` to row 0,
   // col 0 of `B_view`. Decompressing SFP is relatively cheap on `AVX3_DL`
   // thanks to its large table lookups, and less so on other targets.
@@ -928,7 +978,7 @@
     // Neither A nor B require padding because `LoopKC` handles remainders.
     if constexpr (hwy::IsSame<TB, BF16>()) {
-      return View(B, row_b, range_kc.begin(), range_kc.Num());
+      return MMImpl::View(B, row_b, range_kc.begin(), range_kc.Num());
     }
 
     const PackedSpan<const TB> B_span = B.PaddedSpan();
@@ -951,8 +1001,6 @@
   }
 
   const MMArgs args_;  // copy for locality
-  const size_t pkg_idx_;
-  const size_t cluster_idx_;  // 0 for sequential and nested.
   const IndexRange range_np_;
 
   // From MMConfig:
@@ -962,52 +1010,7 @@
   const IndexRangePartition ranges_nc_;
   const MMOrder order_;
   const size_t inner_tasks_;
-  const size_t line_bytes_;
-};  // MMPerPackage
-
-// Stateless, wraps member functions.
-struct MMImpl {
-  // Returns existing entry for the given key or -1.
-  static HWY_INLINE intptr_t IndexOfKey(MMKeys::Key key, const MMKeys& keys) {
-    const hwy::Span<const uint64_t> all_keys = keys.Keys();
-    // TODO: SIMD scan
-    for (size_t i = 0; i < all_keys.size(); ++i) {
-      if (all_keys[i] == key) return static_cast<intptr_t>(i);
-    }
-    return -1;
-  }
-
-  // Called from `MatMul` from two places: either with the next autotune config,
-  // or with the best config.
-  template <typename TA, typename TB, typename TC>
-  static HWY_NOINLINE void DoMatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
-                                    RowPtrs<TC> C_rows, const MMArgs& args,
-                                    const MMConfig& config, MMOptions options) {
-    PROFILER_ZONE("MM.DoMatMul");
-    const size_t pkg_idx = 0;
-    HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1);
-    const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
-    switch (options.parallelism_type) {
-      case ParallelismType::kNested:
-        HWY_DASSERT(options.cluster_idx == 0);
-        MMPerPackage(A.Extents(), args, config, pkg_idx, options.cluster_idx,
-                     range_np)(MMNestedParallelPolicy(), A, B, C_rows);
-        break;
-      case ParallelismType::kSequential:
-        MMPerPackage(A.Extents(), args, config, pkg_idx, options.cluster_idx,
-                     range_np)(MMSequentialPolicy(), A, B, C_rows);
-      case ParallelismType::kWithinCluster:
-        MMPerPackage(A.Extents(), args, config, pkg_idx, options.cluster_idx,
-                     range_np)(MMClusterParallelPolicy(), A, B, C_rows);
-        break;
-      default:
-        HWY_ABORT("Parallelism type %s not implemented.",
-                  static_cast<int>(options.parallelism_type));
-        break;
-    }
-  }
-};
+};  // MMState
 
 // Computes the matrix product `A * B * scale [+ add]` and stores it in `C`.
 //
@@ -1033,17 +1036,19 @@ template <typename TA, typename TB, typename TC>
 HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                               const float* HWY_RESTRICT add, MatMulEnv& env,
                               MatPtrT<TC>& C, MMOptions options = MMOptions()) {
-  RowPtrs<TC> C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[0]);
+  HWY_DASSERT(options.cluster_idx < env.row_ptrs.size());
+  RowPtrs<TC> C_rows =
+      GetOrSetTempRowPtrs(C, env.row_ptrs[options.cluster_idx]);
 
   const Allocator& allocator = env.ctx.allocator;
   const size_t M = A.Rows();
   const size_t K = A.Cols();
   const size_t N = B.Rows();
 
   const MMKeys::Key key = MMKeys::KeyFromDims(M, K, N);
-  intptr_t index = MMImpl::IndexOfKey(key, env.keys);
+  intptr_t index = MMImpl::IndexOfKey(key, env.keys[options.cluster_idx]);
   // First time we see this shape/key.
   if (HWY_UNLIKELY(index < 0)) {
-    env.keys.Append(key, allocator);
+    env.keys[options.cluster_idx].Append(key, allocator);
 
     size_t max_packages = kMaxPackages;
     // For low-batch, multiple sockets only help if binding is enabled.
@@ -1052,16 +1057,19 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
     }
 
     // invalidates `MMAutoTune::Best()`
-    index = env.per_key.size();
-    env.per_key.push_back(MMPerKey(env.ctx, max_packages, N, sizeof(TC), kNR));
+    std::vector<MMPerKey>& stored_keys = env.per_key[options.cluster_idx];
+    index = stored_keys.size();
+    stored_keys.push_back(MMPerKey(env.ctx, max_packages, N, sizeof(TC), kNR));
   }
-  MMPerKey& per_key = env.per_key[index];
+  MMPerKey& per_key = env.per_key[options.cluster_idx][index];
   MMAutoTune<MMConfig>& tuner = per_key.autotune;
 
   const MMArgs args(env, per_key, static_cast<double>(A.Scale()) * B.Scale(),
-                    add);
+                    add, options);
   if (HWY_LIKELY(tuner.Best())) {
-    MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best(), options);
+    const MMState state(A.Extents(), args, *tuner.Best());
+    const StridedViewBF A_view = MMImpl::MaybeDecompressA(A, args);
+    state.DispatchParallelism(A_view, B, C_rows);
     return &per_key;
   }
@@ -1089,7 +1097,9 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
     const MMConfig& cfg = tuner.NextConfig();
     const uint64_t t0 = hwy::timer::Start();
-    MMImpl::DoMatMul(A, B, C_rows, args, cfg, options);
+    MMState state(A.Extents(), args, cfg);
+    const StridedViewBF A_view = MMImpl::MaybeDecompressA(A, args);
+    state.DispatchParallelism(A_view, B, C_rows);
     const uint64_t t1 =
         env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
     const double min_elapsed = static_cast<double>(tuner.NotifyTicks(t1 - t0)) /
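A sketch of how a caller now selects the MatMul threading strategy via MMOptions (illustrative, not part of the diff; A, B, C, and env are assumed to exist):

  MMOptions options;
  options.parallelism = ParallelismStrategy::kWithinCluster;
  options.cluster_idx = 2;  // selects this cluster's pool, storage, keys and row_ptrs
  MatMul(A, B, /*add=*/nullptr, env, C, options);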


@@ -405,20 +405,15 @@ IndexRangePartition MMRangesOfNP(ThreadingContext& ctx, size_t max_packages,
 MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx) {
   // Create storage per cluster. This only applies to in-cluster parallelism.
   // For nested and sequential parallelism, a single MMStorage is used.
-  size_t num_packages = ctx.topology.NumPackages();
-  size_t num_clusters = 0;
-  for (size_t pkg_idx = 0; pkg_idx < num_packages; ++pkg_idx) {
-    num_clusters += ctx.topology.NumClusters(pkg_idx);
-  }
+  const size_t num_clusters = ctx.pools.AllClusters(/*pkg_idx=*/0).NumWorkers();
   storage.reserve(num_clusters);
   for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) {
     storage.push_back(MMStorage(ctx));
+    row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(kMaxBatchSize));  // C
   }
 
   char cpu100[100];
   have_timer_stop = hwy::platform::HaveTimerStop(cpu100);
-  row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(kMaxBatchSize));  // C
 }
 
 void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC) {


@@ -58,147 +58,127 @@ IndexRangePartition MMRangesOfNP(ThreadingContext& ctx, size_t max_packages,
                                  size_t N, size_t sizeof_TC, size_t nr);
 
 struct MMOptions {
-  ParallelismType parallelism_type = ParallelismType::kNested;
-  uint8_t cluster_idx = 0;
+  uint32_t cluster_idx = 0;  // for `parallelism == kWithinCluster`.
+  ParallelismStrategy parallelism = ParallelismStrategy::kHierarchical;
 };
 
-struct MMSequentialPolicy {
-  template <class Func>
-  static void ForPkg(ThreadingContext& ctx, const size_t max_packages,
-                     const Func& func) {
-    func(/*pkg_idx=*/0);
-  }
-
+// Policy classes for parallelism, implementing some of `ParallelismStrategy`.
+struct MMParallelNone {
   template <class Func>
-  static void ForNP(ThreadingContext& ctx, const IndexRange& range_np,
-                    size_t nx_multiple, size_t inner_tasks, size_t pkg_idx,
-                    size_t cluster_idx, const Func& func) {
+  void ForNP(ThreadingContext& ctx, const IndexRange& range_np,
+             size_t nx_multiple, size_t inner_tasks, size_t cluster_idx,
+             const Func& func) const {
     HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
-    const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage() +
-                            cluster_idx * ctx.pools.MaxWorkersPerCluster();
-    func(range_np, base_idx);
+    const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster();
+    func(range_np, worker);
   }
 
   template <class Func>
-  static void ForRangesMC_NC(ThreadingContext& ctx,
-                             const IndexRangePartition& ranges_mc,
-                             const IndexRangePartition& ranges_nc,
-                             size_t pkg_idx, size_t cluster_idx,
-                             const Func& func) {
-    const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage() +
-                            cluster_idx * ctx.pools.MaxWorkersPerCluster();
+  void ForRangesMC_NC(ThreadingContext& ctx,
+                      const IndexRangePartition& ranges_mc,
+                      const IndexRangePartition& ranges_nc, size_t cluster_idx,
+                      const Func& func) const {
+    const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster();
     for (size_t i = 0; i < ranges_mc.NumTasks(); ++i) {
       const IndexRange range_mc = ranges_mc.Range(i);
       for (size_t j = 0; j < ranges_nc.NumTasks(); ++j) {
         const IndexRange range_nc = ranges_nc.Range(j);
-        func(range_mc, range_nc, base_idx);
+        func(range_mc, range_nc, worker);
       }
     }
   }
 
   template <class Func>
-  static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
-                         size_t pkg_idx, size_t cluster_idx, const Func& func) {
-    const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage() +
-                            cluster_idx * ctx.pools.MaxWorkersPerCluster();
+  void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
+                  size_t cluster_idx, const Func& func) const {
+    const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster();
     for (uint64_t row_a = range_mc.begin(); row_a < range_mc.end(); ++row_a) {
-      func(row_a, base_idx);
+      func(row_a, worker);
     }
   }
 };
 
-struct MMClusterParallelPolicy {
-  template <class Func>
-  static void ForPkg(ThreadingContext& ctx, const size_t max_packages,
-                     const Func& func) {
-    func(/*pkg_idx=*/0);
-  }
-
+struct MMParallelWithinCluster {
   template <class Func>
-  static void ForNP(ThreadingContext& ctx, const IndexRange& range_np,
-                    size_t nx_multiple, size_t inner_tasks, size_t pkg_idx,
-                    size_t cluster_idx, const Func& func) {
+  void ForNP(ThreadingContext& ctx, const IndexRange& range_np,
+             size_t nx_multiple, size_t inner_tasks, size_t cluster_idx,
+             const Func& func) const {
     HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
+    const size_t pkg_idx = 0;
     hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
+    const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster();
     const IndexRangePartition worker_ranges = StaticPartition(
         range_np, cluster.NumWorkers() * inner_tasks, nx_multiple);
     ParallelizeOneRange(worker_ranges, cluster,
                         [&](const IndexRange& worker_range, size_t worker) {
-                          func(worker_range, worker);
+                          func(worker_range, base + worker);
                         });
   }
 
   template <class Func>
-  static void ForRangesMC_NC(ThreadingContext& ctx,
-                             const IndexRangePartition& ranges_mc,
-                             const IndexRangePartition& ranges_nc,
-                             size_t pkg_idx, size_t cluster_idx,
-                             const Func& func) {
+  void ForRangesMC_NC(ThreadingContext& ctx,
+                      const IndexRangePartition& ranges_mc,
+                      const IndexRangePartition& ranges_nc, size_t cluster_idx,
+                      const Func& func) const {
+    const size_t pkg_idx = 0;
     hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
+    const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster();
     // Low-batch: avoid Divide/Remainder.
     if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
       ParallelizeOneRange(ranges_nc, cluster,
-                          [&](const IndexRange& range_nc, size_t thread) {
-                            func(ranges_mc.Range(0), range_nc, thread);
+                          [&](const IndexRange& range_nc, size_t worker) {
+                            func(ranges_mc.Range(0), range_nc, base + worker);
                           });
     } else {
       ParallelizeTwoRanges(
           ranges_mc, ranges_nc, cluster,
           [&](const IndexRange& range_mc, const IndexRange& range_nc,
-              size_t thread) { func(range_mc, range_nc, thread); });
+              size_t worker) { func(range_mc, range_nc, base + worker); });
     }
   }
 
   template <class Func>
-  static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
-                         size_t pkg_idx, size_t cluster_idx, const Func& func) {
+  void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
+                  size_t cluster_idx, const Func& func) const {
+    const size_t pkg_idx = 0;
     hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
-    cluster.Run(range_mc.begin(), range_mc.end(),
-                [&](uint64_t row_a, size_t thread) { func(row_a, thread); });
+    const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster();
+    cluster.Run(
+        range_mc.begin(), range_mc.end(),
+        [&](uint64_t row_a, size_t worker) { func(row_a, base + worker); });
   }
 };
 
-struct MMNestedParallelPolicy {
-  template <class Func>
-  static void ForPkg(ThreadingContext& ctx, const size_t max_packages,
-                     const Func& func) {
-    if constexpr (kMaxPackages > 1) {
-      ctx.pools.AllPackages().Run(
-          0, HWY_MIN(max_packages, ctx.pools.NumPackages()),
-          [&](uint64_t task, size_t pkg_idx) {
-            HWY_DASSERT(task == pkg_idx);
-            (void)task;
-            func(pkg_idx);
-          });
-    } else {
-      func(/*pkg_idx=*/0);
-    }
-  }
-
+struct MMParallelHierarchical {
   // Cluster/CCX-aware parallel-for over B rows in `range_np`. `nx_multiple` is
   // the granularity of per-cluster tasks. Calls `func(worker_range, worker)`.
+  // `cluster_idx` is not used here as all clusters within a package are used.
   template <class Func>
-  static void ForNP(ThreadingContext& ctx, const IndexRange& range_np,
-                    size_t nx_multiple, size_t inner_tasks, size_t pkg_idx,
-                    size_t /*cluster_idx*/, const Func& func) {
+  void ForNP(ThreadingContext& ctx, const IndexRange& range_np,
+             size_t nx_multiple, size_t inner_tasks,
+             HWY_MAYBE_UNUSED size_t caller_cluster_idx,
+             const Func& func) const {
     HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
-    const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage();
+    HWY_DASSERT(caller_cluster_idx == 0);
     // Single cluster: parallel-for over static partition of `range_np`.
+    const size_t pkg_idx = 0;
     hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx);
     const size_t num_clusters = all_clusters.NumWorkers();
     if (num_clusters == 1) {
-      hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, 0);
+      const size_t cluster_idx = 0;
+      hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
       const IndexRangePartition worker_ranges = StaticPartition(
           range_np, cluster.NumWorkers() * inner_tasks, nx_multiple);
       return ParallelizeOneRange(
           worker_ranges, cluster,
-          [&](const IndexRange& worker_range, size_t thread) {
-            func(worker_range, pkg_base + thread);
+          [&](const IndexRange& worker_range, size_t worker) {
+            func(worker_range, worker);
          });
    }
@@ -210,28 +190,29 @@ struct MMNestedParallelPolicy {
         [&](const IndexRange& nx_range, const size_t cluster_idx) {
           hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
           const size_t cluster_base =
-              pkg_base + cluster_idx * ctx.pools.MaxWorkersPerCluster();
+              cluster_idx * ctx.pools.MaxWorkersPerCluster();
           // Parallel-for over sub-ranges of `cluster_range` within the cluster.
           const IndexRangePartition worker_ranges = StaticPartition(
               nx_range, cluster.NumWorkers() * inner_tasks, nx_multiple);
           ParallelizeOneRange(
               worker_ranges, cluster,
-              [&](const IndexRange& worker_range, size_t thread) {
-                func(worker_range, cluster_base + thread);
+              [&](const IndexRange& worker_range, size_t worker) {
+                func(worker_range, cluster_base + worker);
               });
         });
   }
 
   // Cluster/CCX-aware parallel-for over blocks (separate subranges of A and B
   // rows). Calls `func(range_mc, range_nc, worker)`.
+  // `cluster_idx` is not used here as all clusters within a package are used.
   template <class Func>
-  static void ForRangesMC_NC(ThreadingContext& ctx,
-                             const IndexRangePartition& ranges_mc,
-                             const IndexRangePartition& ranges_nc,
-                             size_t pkg_idx, size_t /*cluster_idx*/,
-                             const Func& func) {
-    const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage();
+  void ForRangesMC_NC(ThreadingContext& ctx,
+                      const IndexRangePartition& ranges_mc,
+                      const IndexRangePartition& ranges_nc,
+                      HWY_MAYBE_UNUSED size_t caller_cluster_idx,
+                      const Func& func) const {
+    const size_t pkg_idx = 0;
+    HWY_DASSERT(caller_cluster_idx == 0);
     hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx);
     // `all_clusters` is a pool with one worker per cluster in a package.
     const size_t num_clusters = all_clusters.NumWorkers();
@@ -243,16 +224,14 @@ struct MMNestedParallelPolicy {
       // Low-batch: avoid Divide/Remainder.
       if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
         return ParallelizeOneRange(
-            ranges_nc, cluster, [&](const IndexRange& range_nc, size_t thread) {
-              func(ranges_mc.Range(0), range_nc, pkg_base + thread);
+            ranges_nc, cluster, [&](const IndexRange& range_nc, size_t worker) {
+              func(ranges_mc.Range(0), range_nc, worker);
             });
       } else {
         return ParallelizeTwoRanges(
             ranges_mc, ranges_nc, cluster,
             [&](const IndexRange& range_mc, const IndexRange& range_nc,
-                size_t thread) {
-              func(range_mc, range_nc, pkg_base + thread);
-            });
+                size_t worker) { func(range_mc, range_nc, worker); });
       }
     }
@ -262,25 +241,23 @@ struct MMNestedParallelPolicy {
ranges_nc, all_clusters, ranges_nc, all_clusters,
[&](const IndexRange range_nc, size_t cluster_idx) { [&](const IndexRange range_nc, size_t cluster_idx) {
const size_t cluster_base = const size_t cluster_base =
pkg_base + cluster_idx * ctx.pools.MaxWorkersPerCluster(); cluster_idx * ctx.pools.MaxWorkersPerCluster();
hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
ParallelizeOneRange(ranges_mc, cluster, ParallelizeOneRange(ranges_mc, cluster,
[&](const IndexRange& range_mc, size_t thread) { [&](const IndexRange& range_mc, size_t worker) {
func(range_mc, range_nc, cluster_base + thread); func(range_mc, range_nc, cluster_base + worker);
}); });
}); });
} }
// Calls `func(row_a, worker)` in parallel. // Calls `func(row_a, worker)` in parallel.
// `cluster_idx` is not used here as all clusters within a package are used.
template <class Func> template <class Func>
static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc, void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
size_t pkg_idx, size_t /*cluster_idx*/, size_t caller_cluster_idx, const Func& func) const {
const Func& func) { HierarchicalParallelFor(range_mc.Num(), ctx.pools,
const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage(); [&](size_t task, size_t worker) {
ctx.pools.Pool(pkg_idx).Run( func(range_mc.begin() + task, worker);
range_mc.begin(), range_mc.end(), });
[&](uint64_t row_a, size_t thread) { func(row_a, pkg_base + thread); });
} }
}; };
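The policy entry points above are now const member functions rather than statics, so a caller can hold a policy instance and invoke it directly. A minimal call-site sketch, assuming `ctx`, `ranges_mc` and `ranges_nc` are already in scope (these names are illustrative, not taken from the MatMul call sites):

  const MMNestedParallelPolicy policy;
  // Only the single package is used, so the caller passes cluster_idx 0.
  policy.ForRangesMC_NC(ctx, ranges_mc, ranges_nc, /*caller_cluster_idx=*/0,
                        [&](const IndexRange& range_mc,
                            const IndexRange& range_nc, size_t worker) {
                          // `worker` is cluster_base + thread, hence unique
                          // across the package's clusters.
                        });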
@ -340,16 +317,12 @@ class MMStorage {
// Internally threaded; must not be called concurrently with the same // Internally threaded; must not be called concurrently with the same
// `ThreadingContext` (used via `parallel`). // `ThreadingContext` (used via `parallel`).
MMStorage(ThreadingContext& ctx) { MMStorage(ThreadingContext& ctx) {
// Per-package allocation so each can decompress A into its own copy.
// Must be padded, see `DoDecompressA`.
// Default to nested parallel policy.
MMNestedParallelPolicy::ForPkg(ctx, kMaxPackages, [&](size_t pkg_idx) {
Allocator& allocator = ctx.allocator; Allocator& allocator = ctx.allocator;
const size_t pkg_idx = 0;
// 0.5 GiB per package. // 0.5 GiB per package. Must be padded, see `DoDecompressA`.
pkg_A_[pkg_idx].reset( pkg_A_[pkg_idx].reset(new MatStorageT<BF16>(
new MatStorageT<BF16>("pkg_A", Extents2D(kMaxBatchSize, kMaxK), "pkg_A", Extents2D(kMaxBatchSize, kMaxK), allocator, MatPadding::kOdd));
allocator, MatPadding::kOdd));
if (allocator.ShouldBind()) { if (allocator.ShouldBind()) {
const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node(); const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node();
@ -360,7 +333,6 @@ class MMStorage {
HWY_WARN("Failed to bind memory for package %zu", pkg_idx); HWY_WARN("Failed to bind memory for package %zu", pkg_idx);
} }
} }
});
} }
// Returns per-package matrix view. Converting A=F32 to BF16 up-front is // Returns per-package matrix view. Converting A=F32 to BF16 up-front is
@ -735,16 +707,18 @@ struct MatMulEnv {
bool print_best = false; bool print_best = false;
std::vector<MMStorage> storage; std::vector<MMStorage> storage;
MMKeys keys; MMKeys keys[kMaxClusters];
std::vector<MMPerKey> per_key; std::vector<MMPerKey> per_key[kMaxClusters];
// Storage for arbitrary output rows, see `MatPtr::AllocateAndAttachRowPtrs`. // Storage for arbitrary output rows, see `MatPtr::AllocateAndAttachRowPtrs`.
// Most MatMul callers use strided MatPtr, but GemmaAttention::ComputeQKV // Most MatMul callers use strided MatPtr, but GemmaAttention::ComputeQKV
// writes to differing KV positions per query / output row. // writes to differing KV positions per query / output row.
// The first entry is sufficient for any C argument, but also potentially // The first `num_clusters` entries are sufficient for any C argument, and
// overwritten by each MatMul. Subsequent entries are precomputed for tensors // must be indexed by `options.cluster_idx`. Note that they are potentially
// and not overwritten. Per-tensor allocations make it likelier that asan // overwritten by each `MatMul`. Subsequent entries are for specific tensors
// detects bugs such as use after free, overrun, and dangling references. // and only written once by their allocator. A per-tensor allocation makes it
// likelier that asan detects bugs such as use after free, overrun, and
// dangling references.
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs; std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
}; };
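With `keys`, `per_key` and the leading `row_ptrs` entries now replicated per cluster, MatMul calls pinned to different clusters appear intended to avoid sharing this mutable state. A caller-side sketch of the indexing, assuming `options.cluster_idx` is valid for the current topology (variable names are illustrative):

  const size_t c = options.cluster_idx;             // < NumClusters()
  MMKeys& keys = env.keys[c];                       // per-cluster autotune keys
  std::vector<MMPerKey>& per_key = env.per_key[c];
  uint8_t** c_rows = env.row_ptrs[c].get();         // scratch C rows, may be
                                                    // overwritten per MatMul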
@ -752,14 +726,21 @@ struct MatMulEnv {
// Reduces register pressure compared to individual values/references. // Reduces register pressure compared to individual values/references.
struct MMArgs { struct MMArgs {
MMArgs(MatMulEnv& env, MMPerKey& per_key, double scale, MMArgs(MatMulEnv& env, MMPerKey& per_key, double scale,
const float* HWY_RESTRICT add) const float* HWY_RESTRICT add, MMOptions options)
: env(&env), per_key(&per_key), scale(scale), add(add) {} : env(&env),
per_key(&per_key),
scale(scale),
add(add),
options(options),
line_bytes(env.ctx.allocator.LineBytes()) {}
MatMulEnv* env; MatMulEnv* env;
MMPerKey* per_key; MMPerKey* per_key;
double scale; double scale;
const float* HWY_RESTRICT add; const float* HWY_RESTRICT add;
MMOptions options;
size_t line_bytes;
}; };
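A construction sketch for the extended `MMArgs`; `MMOptions` is assumed here to be default-constructible and to carry the `cluster_idx` referenced in `MatMulEnv`, while `env`, `per_key`, the scale and the bias pointer are whatever the caller already has (the values below are placeholders):

  MMOptions options;  // assumption: defaults select cluster_idx 0
  const MMArgs args(env, per_key, /*scale=*/1.0, /*add=*/nullptr, options);
  // line_bytes is captured once from env.ctx.allocator.LineBytes(), so hot
  // loops read a plain size_t instead of chasing two pointers.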
// Wrapper over hwy::Zone that is only enabled when autotuning finished. // Wrapper over hwy::Zone that is only enabled when autotuning finished.


@ -501,11 +501,11 @@ void RMSNormBatched(const MatPtrT<XT>& activations, const MatPtr& weights,
HWY_DASSERT(activations.SameShape(out)); HWY_DASSERT(activations.SameShape(out));
CallUpcasted(&weights, [&](const auto* weights_t) { CallUpcasted(&weights, [&](const auto* weights_t) {
ParallelFor( ParallelFor(ParallelismStrategy::kFlat, activations.Rows(), ctx,
ParallelismType::kAcrossClusters, activations.Rows(), ctx.pools,
cluster_idx, [&](uint64_t token_idx, size_t worker) { cluster_idx, [&](uint64_t token_idx, size_t worker) {
RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(), RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(),
out.Row(token_idx), activations.Cols(), ctx.profiler, worker); out.Row(token_idx), activations.Cols(), ctx.profiler,
worker);
}); });
}); });
} }
@ -517,11 +517,11 @@ void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT<XT>& inout,
HWY_DASSERT(weights.Cols() == inout.Cols()); HWY_DASSERT(weights.Cols() == inout.Cols());
CallUpcasted(&weights, [&](const auto* weights_t) { CallUpcasted(&weights, [&](const auto* weights_t) {
ParallelFor( ParallelFor(ParallelismStrategy::kFlat, inout.Rows(), ctx, cluster_idx,
ParallelismType::kAcrossClusters, inout.Rows(), ctx.pools, cluster_idx,
[&](uint64_t token_idx, size_t worker) { [&](uint64_t token_idx, size_t worker) {
RMSNormInplace(weights_t->PackedScale1(), inout.Row(token_idx), RMSNormInplace(weights_t->PackedScale1(),
inout.Cols(), ctx.profiler, worker); inout.Row(token_idx), inout.Cols(),
ctx.profiler, worker);
}); });
}); });
} }
@ -548,8 +548,8 @@ static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x, MatPtrT<float>& out,
ThreadingContext& ctx, ThreadingContext& ctx,
size_t cluster_idx = 0) { size_t cluster_idx = 0) {
HWY_DASSERT(out.SameShape(x)); HWY_DASSERT(out.SameShape(x));
ParallelFor(ParallelismType::kAcrossClusters, out.Rows(), ctx.pools, ParallelFor(ParallelismStrategy::kFlat, out.Rows(), ctx, cluster_idx,
cluster_idx, [&](uint64_t token_idx, size_t worker) { [&](uint64_t token_idx, size_t worker) {
AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(), AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(),
ctx.profiler, worker); ctx.profiler, worker);
}); });
@ -782,8 +782,8 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCapBatched(
const float cap, MatPtrT<float>& x, const hwy::BitSet4096<>& non_eos, const float cap, MatPtrT<float>& x, const hwy::BitSet4096<>& non_eos,
ThreadingContext& ctx, size_t cluster_idx = 0) { ThreadingContext& ctx, size_t cluster_idx = 0) {
if (cap == 0.0f) return; if (cap == 0.0f) return;
ParallelFor(ParallelismType::kAcrossClusters, x.Rows(), ctx.pools, ParallelFor(ParallelismStrategy::kFlat, x.Rows(), ctx, cluster_idx,
cluster_idx, [&](uint64_t task, size_t worker) { [&](uint64_t task, size_t worker) {
if (non_eos.Get(task)) { if (non_eos.Get(task)) {
LogitsSoftCap(cap, x.Row(task), x.Cols(), ctx.profiler, LogitsSoftCap(cap, x.Row(task), x.Cols(), ctx.profiler,
worker); worker);


@ -35,6 +35,8 @@ namespace gcpp {
// is disabled if this is 1. // is disabled if this is 1.
HWY_INLINE_VAR constexpr size_t kMaxPackages = 1; HWY_INLINE_VAR constexpr size_t kMaxPackages = 1;
HWY_INLINE_VAR constexpr size_t kMaxClusters = 128; // TODO: shrink
// TODO: extend to 16k after updating non_eos. // TODO: extend to 16k after updating non_eos.
HWY_INLINE_VAR constexpr size_t kMaxBatchSize = 4096; HWY_INLINE_VAR constexpr size_t kMaxBatchSize = 4096;


@ -326,7 +326,8 @@ void ParallelizeTwoRanges(const IndexRangePartition& get1,
// Calls `func(task, worker)` for each task in `[0, num_tasks)`. Parallelizes // Calls `func(task, worker)` for each task in `[0, num_tasks)`. Parallelizes
// over clusters of ONE package, then within each cluster. // over clusters of ONE package, then within each cluster.
template <class Func> template <class Func>
void NestedParallelFor(size_t num_tasks, NestedPools& pools, const Func& func) { void HierarchicalParallelFor(size_t num_tasks, NestedPools& pools,
const Func& func) {
// Even if there are multiple packages, we only use the first. // Even if there are multiple packages, we only use the first.
const size_t pkg_idx = 0; const size_t pkg_idx = 0;
@ -356,59 +357,6 @@ void NestedParallelFor(size_t num_tasks, NestedPools& pools, const Func& func) {
}); });
} }
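A call-site sketch of the renamed helper, mirroring how the new `ForRangeMC` above uses it; `num_tasks` and `ctx` are assumed to be in scope:

  HierarchicalParallelFor(num_tasks, ctx.pools,
                          [&](size_t task, size_t worker) {
                            // All cores of the first package participate; per
                            // this change, `worker` is intended to be unique
                            // across its clusters.
                          });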
// Which pool(s) to use for parallelizing:
enum class ParallelismType : uint8_t {
// None: single-threaded loop on the calling thread.
kSequential,
// One thread per cluster within the first package; or one per core if there
// is only one cluster. Use for few or lightweight tasks, or to maximize
// memory bandwidth availability.
kAcrossClusters,
// All cores within the cluster identified by `cluster_idx`. Use if already
// within a `kAcrossClusters` parallel-for, or if latency is more important
// than memory bandwidth.
kWithinCluster,
// First statically partitions `kAcrossClusters`, then `kWithinCluster`. This
// utilizes all cores, but has higher fork-join overhead (two barriers); use
// if there are many or heavy tasks.
kNested,
};
// Calls `func(task, worker)` for each `task` in `[0, num_tasks)`, with the
// number/type of workers determined by `parallelism`. `cluster_idx` is only
// used if `parallelism == kWithinCluster`.
template <class Func>
void ParallelFor(ParallelismType parallelism, size_t num_tasks,
NestedPools& pools, size_t cluster_idx, const Func& func) {
if (cluster_idx != 0) {
// If already running across clusters, must not use across-cluster modes.
HWY_DASSERT(parallelism != ParallelismType::kAcrossClusters &&
parallelism != ParallelismType::kNested);
}
const size_t pkg_idx = 0;
switch (parallelism) {
case ParallelismType::kSequential:
for (size_t task = 0; task < num_tasks; ++task) {
func(task, /*worker=*/0);
}
return;
case ParallelismType::kAcrossClusters:
return pools.Pool(pkg_idx).Run(
0, num_tasks,
[&](uint64_t task, size_t worker) { func(task, worker); });
case ParallelismType::kWithinCluster:
return pools.Cluster(pkg_idx, cluster_idx)
.Run(0, num_tasks,
[&](uint64_t task, size_t worker) { func(task, worker); });
case ParallelismType::kNested:
return NestedParallelFor(num_tasks, pools, func);
}
}
} // namespace gcpp } // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_ #endif // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_


@ -116,6 +116,95 @@ struct ThreadingContext {
NestedPools pools; NestedPools pools;
}; };
// Describes the strategy for distributing parallel work across cores.
enum class ParallelismStrategy : uint8_t {
// Execute using a single-threaded loop on the calling thread. The `worker`
// index passed to the user's `Func` is unique across clusters.
kNone,
// One thread per cluster within the first package. The `worker` index passed
// to the user's `Func` is a `cluster_idx < NumClusters()`. Some CPUs may
// only have a single cluster, hence `Func` should also contain a nested
// `ParallelFor` with `kWithinCluster`.
kAcrossClusters,
// All cores within the cluster identified by `cluster_idx`. The `worker`
// index passed to the user's `Func` is unique across clusters. Choose this
// strategy if already within a `ParallelFor` call with `kAcrossClusters`,
// or latency is more important than memory bandwidth.
kWithinCluster,
// Equivalent to `kAcrossClusters` if there are multiple clusters, otherwise
// `kWithinCluster`. Use for few or lightweight tasks (this only uses a
// single pool and barrier), or to maximize memory bandwidth availability.
kFlat,
// First statically partitions `kAcrossClusters`, then `kWithinCluster`. This
// utilizes all cores, but has higher fork-join overhead (two barriers); use
// if there are many or heavy tasks.
kHierarchical,
};
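Except for `kAcrossClusters`, which reports the raw `cluster_idx`, these strategies hand `Func` a worker index that is unique across clusters, following the convention visible in `ParallelFor` below. A worked example, assuming `MaxWorkersPerCluster()` returns 8 (the value is illustrative):

  //   worker = cluster_idx * MaxWorkersPerCluster() + worker_within_cluster
  //   cluster 0, local worker 0 -> worker 0
  //   cluster 0, local worker 7 -> worker 7
  //   cluster 2, local worker 3 -> worker 2 * 8 + 3 = 19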
// Calls `func(task, worker)` for each `task` in `[0, num_tasks)`, with the
// number/type of workers determined by `parallelism`. `cluster_idx` selects the
// cluster for `kWithinCluster` and offsets the `worker` index for `kNone`; pass
// 0 if unknown.
template <class Func>
void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
ThreadingContext& ctx, size_t cluster_idx, const Func& func) {
HWY_DASSERT(ctx.topology.NumPackages() == 1);
const size_t pkg_idx = 0;
HWY_DASSERT(cluster_idx < ctx.topology.NumClusters(pkg_idx));
if (cluster_idx != 0) {
// If already running across clusters, only use within-cluster modes.
HWY_DASSERT(parallelism == ParallelismStrategy::kNone ||
parallelism == ParallelismStrategy::kWithinCluster);
}
switch (parallelism) {
case ParallelismStrategy::kNone: {
const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster();
for (size_t task = 0; task < num_tasks; ++task) {
func(task, worker);
}
return;
}
case ParallelismStrategy::kAcrossClusters:
return ctx.pools.AllClusters(pkg_idx).Run(
0, num_tasks,
[&](uint64_t task, size_t cluster_idx) { func(task, cluster_idx); });
case ParallelismStrategy::kWithinCluster: {
// Ensure the worker argument is unique across clusters, because it is
// used for TLS indexing, e.g. in profiler.h.
const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster();
return ctx.pools.Cluster(pkg_idx, cluster_idx)
.Run(0, num_tasks, [&](uint64_t task, size_t worker) {
func(task, base + worker);
});
}
case ParallelismStrategy::kFlat: {
// With a single cluster, run within it directly; otherwise compute a
// per-cluster worker base so worker indices are consistent and non-overlapping.
hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx);
const size_t num_clusters = all_clusters.NumWorkers();
if (num_clusters == 1) {
return ctx.pools.Cluster(pkg_idx, cluster_idx)
.Run(0, num_tasks,
[&](uint64_t task, size_t worker) { func(task, worker); });
}
return ctx.pools.AllClusters(pkg_idx).Run(
0, num_tasks, [&](uint64_t task, size_t cluster_idx) {
const size_t worker =
cluster_idx * ctx.pools.MaxWorkersPerCluster();
func(task, worker);
});
}
case ParallelismStrategy::kHierarchical:
return HierarchicalParallelFor(num_tasks, ctx.pools, func);
}
}
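A usage sketch of the nesting recommended by the `kAcrossClusters` comment: the outer loop receives a `cluster_idx` as its worker and forwards it to an inner `kWithinCluster` loop, so the inner worker indices stay unique across clusters; `num_outer` and `num_inner` are placeholder task counts:

  ParallelFor(ParallelismStrategy::kAcrossClusters, num_outer, ctx,
              /*cluster_idx=*/0, [&](uint64_t outer, size_t cluster_idx) {
                ParallelFor(ParallelismStrategy::kWithinCluster, num_inner, ctx,
                            cluster_idx, [&](uint64_t inner, size_t worker) {
                              // worker = cluster_idx * MaxWorkersPerCluster()
                              //        + worker_within_cluster.
                            });
              });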
} // namespace gcpp } // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_CONTEXT_H_ #endif // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_CONTEXT_H_