mirror of https://github.com/google/gemma.cpp.git
MatMul simplification, threading strategy improvements
- remove MatMul f32 special case (smaller code)
- types: Add u32/u64 for use by Activations
- move renamed ParallelismStrategy to threading_context so we can pass ctx
- ensure worker index is unique across clusters
- matmul.h: const member functions for renamed policy classes (easier to call)

PiperOrigin-RevId: 802848086

parent 74ffe079c4
commit 7263ab8445
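The "ensure worker index is unique across clusters" point is easiest to see as a small sketch. The snippet below is illustrative only, not code from this commit: the constants stand in for the real `ThreadingContext`/pool queries, and it simply shows how offsetting each cluster's local worker id by a per-cluster base keeps per-worker storage slots from colliding.

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical constants standing in for ThreadingContext queries.
constexpr size_t kNumClusters = 2;
constexpr size_t kMaxWorkersPerCluster = 4;

// Global worker index: unique across clusters, so per-worker buffers
// indexed by it never alias between clusters.
size_t GlobalWorker(size_t cluster_idx, size_t local_worker) {
  return cluster_idx * kMaxWorkersPerCluster + local_worker;
}

int main() {
  std::vector<int> per_worker_storage(kNumClusters * kMaxWorkersPerCluster, 0);
  for (size_t c = 0; c < kNumClusters; ++c) {
    for (size_t w = 0; w < kMaxWorkersPerCluster; ++w) {
      per_worker_storage[GlobalWorker(c, w)] += 1;  // each slot touched once
    }
  }
  for (int count : per_worker_storage) std::printf("%d ", count);  // all 1s
  return 0;
}
```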
@@ -191,12 +191,13 @@ constexpr bool SupportsPointerArithmetic() {
   return !IsNuqStream<Packed>();
 }
 
-// Tensor types for loading weights.
-enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64 };
+// Tensor types for loading weights. Not all of these are supported weight
+// types, some are only used for `Activations`.
+enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kU32, kU64 };
 // These are used in `ModelConfig.Specifier`, hence the strings will not
 // change, though new ones may be added.
-static constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16",
-                                               "sfp", "nuq", "f64"};
+static constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp",
+                                               "nuq", "f64", "u32", "u64"};
 static constexpr size_t kNumTypes =
     sizeof(kTypeStrings) / sizeof(kTypeStrings[0]);
 static constexpr size_t kTypeBits[] = {
@@ -206,6 +207,8 @@ static constexpr size_t kTypeBits[] = {
     8 * sizeof(SfpStream),
     4 /* NuqStream, actually 4.5 */,
     8 * sizeof(double),
+    8 * sizeof(uint32_t),
+    8 * sizeof(uint64_t),
 };
 
 static inline bool EnumValid(Type type) {
@@ -226,6 +229,10 @@ Type TypeEnum() {
     return Type::kNUQ;
   } else if constexpr (hwy::IsSame<Packed, double>()) {
     return Type::kF64;
+  } else if constexpr (hwy::IsSame<Packed, uint32_t>()) {
+    return Type::kU32;
+  } else if constexpr (hwy::IsSame<Packed, uint64_t>()) {
+    return Type::kU64;
   } else {
     HWY_DASSERT(false);
     return Type::kUnknown;
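The `Type` enum, `kTypeStrings`, and `kTypeBits` above are parallel tables kept in the same order, so an enum value doubles as an index into both arrays. Below is a minimal standalone sketch of that pattern, not the actual gemma.cpp code: the bit counts here are illustrative placeholders (the real table uses `sizeof` and a 4-bit NUQ entry).

```cpp
#include <cstddef>
#include <cstdio>

enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kU32, kU64 };
static constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp",
                                               "nuq",     "f64", "u32",  "u64"};
static constexpr size_t kTypeBits[] = {0, 32, 16, 8, 4, 64, 32, 64};

// Enum order matches both tables, so casting the enum gives the index.
const char* TypeName(Type t) { return kTypeStrings[static_cast<size_t>(t)]; }
size_t TypeBits(Type t) { return kTypeBits[static_cast<size_t>(t)]; }

int main() {
  std::printf("%s: %zu bits\n", TypeName(Type::kU32), TypeBits(Type::kU32));
  return 0;
}
```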
@@ -21,14 +21,14 @@
 #include <stdint.h>
 
 #include <atomic>
+#include <memory>
 #include <vector>
 
 #include "gemma/configs.h"  // ModelConfig
-#include "ops/matmul.h"      // MatMulEnv
-#include "ops/ops.h"         // CreateInvTimescale
-#include "util/allocator.h"  // Allocator
-#include "util/basics.h"     // BF16
-#include "util/mat.h"        // MatStorageT
+#include "ops/ops.h"      // CreateInvTimescale
+#include "util/basics.h"  // BF16
+#include "util/mat.h"     // MatStorageT
+#include "util/threading_context.h"
 
 namespace gcpp {
 
@@ -150,24 +150,28 @@ struct AttentionActivations {
 
 struct Activations {
   Activations(const ModelConfig& config, size_t batch_size, size_t seq_len,
-              const Allocator& allocator,
+              ThreadingContext& ctx,
               std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
       : layer_config(config.layer_configs[0]),
 
-        x(MatFactory("x", batch_size, config.model_dim, allocator)),
-        x_bf(MatFactory("x_bf", batch_size, config.model_dim, allocator)),
-        logits(MatFactory("logits", batch_size, config.vocab_size, allocator)),
+        x(MatFactory("x", batch_size, config.model_dim, ctx.allocator)),
+        x_bf(MatFactory("x_bf", batch_size, config.model_dim, ctx.allocator)),
+        logits(
+            MatFactory("logits", batch_size, config.vocab_size, ctx.allocator)),
 
         pre_ffw_rms_out(MatFactory("pre_ffw_rms_out", batch_size,
-                                   config.model_dim, allocator)),
-        C1(MatFactory("C1", batch_size, layer_config.ff_hidden_dim, allocator)),
-        C2(MatFactory("C2", batch_size, layer_config.ff_hidden_dim, allocator)),
-        ffw_out(MatFactory("ffw_out", batch_size, config.model_dim, allocator)),
+                                   config.model_dim, ctx.allocator)),
+        C1(MatFactory("C1", batch_size, layer_config.ff_hidden_dim,
+                      ctx.allocator)),
+        C2(MatFactory("C2", batch_size, layer_config.ff_hidden_dim,
+                      ctx.allocator)),
+        ffw_out(
+            MatFactory("ffw_out", batch_size, config.model_dim, ctx.allocator)),
 
-        attention(config, layer_config, batch_size, seq_len, allocator,
+        attention(config, layer_config, batch_size, seq_len, ctx.allocator,
                   row_ptrs),
         griffin(config, config.model == Model::GRIFFIN_2B ? batch_size : 0,
-                allocator) {
+                ctx.allocator) {
     HWY_ASSERT(batch_size != 0);
 
     // For MatMul outputs, precompute their row pointers.
@@ -19,7 +19,6 @@
 #include <vector>
 
 #include "compression/types.h"  // GEMMA_DISABLED_TARGETS
-#include "util/threading_context.h"
 #ifndef HWY_DISABLED_TARGETS
 #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
 #endif  // HWY_DISABLED_TARGETS
@@ -29,7 +28,7 @@
 #include "gemma/gemma.h"
 #include "gemma/weights.h"
 #include "util/threading.h"
-#include "hwy/contrib/thread_pool/thread_pool.h"
+#include "util/threading_context.h"
 #include "hwy/profiler.h"
 
 // Compiles this file for multiple architectures via "foreach_target.h", to
@@ -234,8 +233,9 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
   {
     PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin");
     // Full parallelism is helpful, kAcrossClusters is insufficient.
-    NestedParallelFor(num_tokens * div_qbatch.GetDivisor() * layer_config.heads,
-                      ctx.pools, func);
+    HierarchicalParallelFor(
+        num_tokens * div_qbatch.GetDivisor() * layer_config.heads, ctx.pools,
+        func);
   }
 }
 
@@ -285,9 +285,9 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
   // Apply positional encodings for K.
   // Note that 2D parallelism is not worth the fork/join overhead because the
   // tasks are very lightweight.
-  env.ctx.pools.Pool(0).Run(
-      0, kv_heads * num_interleaved,
-      [&](uint64_t task, size_t thread) HWY_ATTR {
+  ParallelFor(
+      ParallelismStrategy::kFlat, kv_heads * num_interleaved, env.ctx,
+      /*cluster_idx=*/0, [&](size_t task, size_t worker) HWY_ATTR {
        const size_t head = task % kv_heads;
        const size_t interleaved_idx = task / kv_heads;
        const size_t qi = div_qbatch.Remainder(interleaved_idx);
@@ -308,12 +308,12 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
        if (layer.key_norm_scale.HasPtr()) {
          CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) {
            RMSNormInplace(weights_t->PackedScale1(), kv_f32, qkv_dim,
-                          env.ctx.profiler, thread);
+                          env.ctx.profiler, worker);
          });
        }
 
        PositionalEncodingQK(kv_f32, layer_idx, layer, activations,
-                            env.ctx.profiler, thread, pos);
+                            env.ctx.profiler, worker, pos);
        CompressPerThread tls;
        Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
      });
@@ -69,9 +69,9 @@ template <class Mat>
 void ActivationBatched(
     ActivationType activation, Mat& c1, ThreadingContext& ctx,
     size_t cluster_idx = 0,
-    ParallelismType parallelism = ParallelismType::kAcrossClusters) {
+    ParallelismStrategy parallelism = ParallelismStrategy::kFlat) {
   using T = typename Mat::T;
-  ParallelFor(parallelism, c1.Rows(), ctx.pools, cluster_idx,
+  ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx,
               [&](uint64_t task, size_t worker) {
                 // Cast to correct type so type deduction works.
                 Activation(activation, c1.Row(task),
@@ -84,16 +84,16 @@ template <class Mat1, class Mat2>
 HWY_NOINLINE void ActivationBatched(
     ActivationType activation, Mat1& c1, const Mat2* c2, ThreadingContext& ctx,
     size_t cluster_idx = 0,
-    ParallelismType parallelism = ParallelismType::kAcrossClusters) {
+    ParallelismStrategy parallelism = ParallelismStrategy::kFlat) {
   HWY_DASSERT(c1.SameShape(*c2));
   if (c2 && c2->HasPtr()) {
-    ParallelFor(parallelism, c1.Rows(), ctx.pools, cluster_idx,
+    ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx,
                 [&](uint64_t task, size_t worker) {
                   Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(),
                              ctx.profiler, worker);
                 });
   } else {  // No multiplier
-    ParallelFor(parallelism, c1.Rows(), ctx.pools, cluster_idx,
+    ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx,
                 [&](uint64_t task, size_t worker) {
                   Activation(activation, c1.Row(task),
                              static_cast<const typename Mat2::T*>(nullptr),
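Both `ActivationBatched` overloads now take a `ParallelismStrategy` and the `ThreadingContext` itself rather than its pools. As a rough standalone illustration of the calling convention (a task index plus a worker index), here is a trivial stand-in for `ParallelFor` that just runs every task on one worker; the real implementation in this repository dispatches to the thread pools inside `ctx` according to the chosen strategy.

```cpp
#include <cstddef>
#include <cstdio>

enum class ParallelismStrategy { kNone, kFlat, kHierarchical };

// Toy stand-in: executes all tasks on worker 0 regardless of strategy.
// The real ParallelFor distributes tasks across cluster/package pools.
template <class Func>
void ParallelForSketch(ParallelismStrategy /*strategy*/, size_t num_tasks,
                       const Func& func) {
  for (size_t task = 0; task < num_tasks; ++task) {
    func(task, /*worker=*/0);
  }
}

int main() {
  ParallelForSketch(ParallelismStrategy::kFlat, 4,
                    [](size_t task, size_t worker) {
                      std::printf("task %zu on worker %zu\n", task, worker);
                    });
  return 0;
}
```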
@@ -574,7 +574,7 @@ void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
                      const WeightsPtrs& weights, KVCache& kv_cache,
                      MatMulEnv& env, TimingInfo& timing_info) {
   Activations activations(config, runtime_config.prefill_tbatch_size,
-                          kv_cache.SeqLen(), env.ctx.allocator, env.row_ptrs);
+                          kv_cache.SeqLen(), env.ctx, env.row_ptrs);
 
   AllQueries all_queries(prompt, pos, prefix_end,
                          hwy::Span<KVCache>(&kv_cache, 1));
@@ -592,7 +592,7 @@ void GenerateBatchT(const ModelConfig& config,
   const size_t max_batch_size = HWY_MAX(runtime_config.decode_qbatch_size,
                                         runtime_config.prefill_tbatch_size);
   Activations activations(config, max_batch_size,
-                          all_queries[0].kv_cache.SeqLen(), env.ctx.allocator,
+                          all_queries[0].kv_cache.SeqLen(), env.ctx,
                           env.row_ptrs);
 
   for (size_t start = 0; start < all_queries.NumQueries();
@@ -616,8 +616,8 @@ void GenerateImageTokensT(const ModelConfig& config,
   const size_t num_tokens = vit_config.max_seq_len;
   prefill_runtime_config.prefill_tbatch_size =
       num_tokens / (vit_config.pool_dim * vit_config.pool_dim);
-  Activations prefill_activations(vit_config, num_tokens, num_tokens,
-                                  env.ctx.allocator, env.row_ptrs);
+  Activations prefill_activations(vit_config, num_tokens, num_tokens, env.ctx,
+                                  env.row_ptrs);
   // Weights are for the full PaliGemma model, not just the ViT part.
   PrefillVit(config, weights, prefill_runtime_config, image, image_tokens,
              prefill_activations, env);
@@ -111,8 +111,8 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
   // Ensure usage conditions are set before autotuning. Both binding and
   // spinning may materially affect the choice of config. No harm in calling
   // BindB/C if there is a single package: they will be a no-op.
-  BindB(b_trans, sizeof(TC), env.parallel);
-  BindC(C, env.parallel);
+  BindB(env.ctx, b_trans, sizeof(TC));
+  BindC(env.ctx, C);
   C.AllocateAndAttachRowPtrs(env.row_ptrs);
 
   Tristate use_spinning = Tristate::kDefault;
@@ -160,10 +160,10 @@ void BenchAllMatMul() {
           ctx.pools.PinString());
   MatMulEnv env(ctx);
 
-  for (size_t batch_size : {1, 4, 128, 512}) {
+  for (size_t batch_size : {128, 512}) {
     constexpr bool kAdd = false;
-    BenchMatMul<BF16, SFP, BF16>(batch_size, 24576, 3072, kAdd, env);
-    BenchMatMul<BF16, SFP, BF16>(batch_size, 3072, 24576, kAdd, env);
+    BenchMatMul<BF16, BF16, BF16>(batch_size, 24576, 3072, kAdd, env);
+    BenchMatMul<BF16, BF16, BF16>(batch_size, 3072, 24576, kAdd, env);
   }
 
   PROFILER_PRINT_RESULTS();
ops/matmul-inl.h (472 changed lines)
@@ -565,46 +565,204 @@ class MMKernel {
   }
 };
 
-// Called on the main thread with the entire N range, or by each package with
-// a static partition of N. This class contains several variants of the
-// outer M/N/K loops, and calls `A2C0` which loops over the inner KC and MC.
-// Its member variables avoid long argument lists in Do*().
-class MMPerPackage {
- public:
-  MMPerPackage(const Extents2D A, const MMArgs& args, const MMConfig& config,
-               size_t pkg_idx, size_t cluster_idx, const IndexRange& range_np)
-      : args_(args),
-        pkg_idx_(pkg_idx),
-        cluster_idx_(cluster_idx),
-        range_np_(range_np),
-        mr_(config.MR()),
-        ranges_mc_(config.RangesOfMC(A.rows)),
-        ranges_kc_(config.RangesOfKC(A.cols)),
-        ranges_nc_(config.RangesOfNC(range_np)),
-        order_(config.Order()),
-        inner_tasks_(config.InnerTasks()),
-        line_bytes_(args.env->ctx.allocator.LineBytes()) {}
-
-  // B and maybe A are decompressed several call layers lower, but not all
-  // member functions depend on TA/TB, so pass them as an argument instead of
-  // templating the class.
-  template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
-  HWY_NOINLINE void operator()(const MMParallelPolicyT& parallel_policy,
-                               const MatPtrT<TA>& A, const MatPtrT<TB>& B,
-                               RowPtrs<TC> C_rows) const {
+// Miscellaneous stateless helper functions.
+struct MMImpl {
+  // Returns existing entry for the given key or -1.
+  static HWY_INLINE intptr_t IndexOfKey(MMKeys::Key key, const MMKeys& keys) {
+    const hwy::Span<const uint64_t> all_keys = keys.Keys();
+    // TODO: SIMD scan
+    for (size_t i = 0; i < all_keys.size(); ++i) {
+      if (all_keys[i] == key) return static_cast<intptr_t>(i);
+    }
+    return -1;
+  }
+
+  static size_t Worker(const MMArgs& args) {
+    return args.options.cluster_idx *
+           args.env->ctx.pools.MaxWorkersPerCluster();
+  }
+
+  template <class Func>
+  static void DispatchParallelism(ParallelismStrategy parallelism,
+                                  const Func& func) {
+    switch (parallelism) {
+      case ParallelismStrategy::kHierarchical:
+        return func(MMParallelHierarchical());
+      case ParallelismStrategy::kNone:
+        return func(MMParallelNone());
+      case ParallelismStrategy::kWithinCluster:
+        return func(MMParallelWithinCluster());
+      default:
+        HWY_UNREACHABLE;
+    }
+  }
+
+  // Decompresses all `M x K` from `A` into padded BF16 `A_view`.
+  static HWY_NOINLINE void DoDecompressA(const MatPtrT<float>& A,
+                                         const StridedViewBF A_view,
+                                         MMParA par_a, const MMArgs& args) {
+    const IndexRange all_M(0, A.Rows());
+    const IndexRange all_K(0, A.Cols());
+    HWY_DASSERT(all_K.Num() == A_view.Cols());
+
+    const hn::ScalableTag<BF16> dbf;
+    const size_t NBF = hn::Lanes(dbf);
+
+    static const auto zone = args.env->ctx.profiler.AddZone("MM.DecompressA");
+
+    const auto do_range =
+        [&](const IndexRange& range_M, const IndexRange& range_K, size_t worker)
+            HWY_ATTR {
+              MMZone mm_zone;
+              mm_zone.MaybeEnter(worker, zone, args);
+
+              const size_t col0 = range_K.begin();
+              const size_t cols = range_K.Num();
+              // Must be a vector multiple, or the last range before row
+              // padding, otherwise `DecompressAndZeroPad` overwrites neighbors.
+              HWY_DASSERT(cols % NBF == 0 || range_K.end() == A.Cols());
+              for (size_t row_a : range_M) {
+                const PackedSpan<const float> from =
+                    MakeSpan(A.Row(row_a) + col0, cols);
+                BF16* HWY_RESTRICT to = A_view.Row(row_a) + col0;
+                DecompressAndZeroPad(dbf, from, 0, to, cols);
+                // Verify that we zero-padded.
+                if constexpr (HWY_IS_DEBUG_BUILD) {
+                  for (size_t i = cols; i < hwy::RoundUpTo(cols, NBF); ++i) {
+                    HWY_DASSERT(hwy::ConvertScalarTo<float>(to[i]) == 0.0f);
+                  }
+                }
+              }
+            };
+
+    switch (par_a) {
+      case MMParA::kNone:
+        do_range(all_M, all_K, MMImpl::Worker(args));
+        break;
+
+      case MMParA::kK1:
+      case MMParA::kK2:
+      case MMParA::kK4: {
+        const size_t inner_tasks = static_cast<size_t>(par_a);
+        // At least one vector, otherwise DecompressAndZeroPad will add
+        // padding, which might overwrite neighboring tasks. Also a whole cache
+        // line to avoid false sharing.
+        const size_t multiple_K = HWY_MAX(NBF, args.line_bytes / sizeof(BF16));
+
+        DispatchParallelism(
+            args.options.parallelism, [&](const auto& parallel) {
+              parallel.ForNP(args.env->ctx, all_K, multiple_K, inner_tasks,
+                             args.options.cluster_idx,
+                             [&](const IndexRange& range_K, size_t worker) {
+                               do_range(all_M, range_K, worker);
+                             });
+            });
+        break;
+      }
+      case MMParA::kM:
+        DispatchParallelism(
+            args.options.parallelism, [&](const auto& parallel) {
+              parallel.ForRangeMC(
+                  args.env->ctx, all_M, args.options.cluster_idx,
+                  [&](size_t row_a, size_t worker) {
+                    do_range(IndexRange(row_a, row_a + 1), all_K, worker);
+                  });
+            });
+        break;
+    }
+  }
+
+  // Autotuning wrapper for `DoDecompressA`.
+  static HWY_INLINE void DecompressA(const MatPtrT<float>& A,
+                                     const StridedViewBF A_view,
+                                     const MMArgs& args) {
+    MMAutoTune<MMParA>& autotune = args.per_key->autotune_par_a[/*pkg_idx=*/0];
+
+    if (HWY_LIKELY(autotune.Best())) {
+      return DoDecompressA(A, A_view, *autotune.Best(), args);
+    }
+
+    // First call: generate candidates.
+    if (HWY_UNLIKELY(!autotune.HasCandidates())) {
+      const MMParA other = (A.Rows() == 1) ? MMParA::kNone : MMParA::kM;
+      std::vector<MMParA> candidates = {MMParA::kK1, MMParA::kK2, MMParA::kK4,
+                                        other};
+      autotune.SetCandidates(candidates);
+    }
+
+    const MMParA& par_a = autotune.NextConfig();
+    const uint64_t t0 = hwy::timer::Start();
+    DoDecompressA(A, A_view, par_a, args);
+    const uint64_t t1 =
+        args.env->have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
+    const uint64_t min_elapsed = autotune.NotifyTicks(t1 - t0);
+    if (HWY_UNLIKELY(args.env->print_measurement && autotune.ShouldPrint())) {
+      fprintf(stderr, "%s,%7.3f\n", StringFromParA(par_a),
+              static_cast<double>(min_elapsed) /
+                  hwy::platform::InvariantTicksPerSecond() * 1E6);
+    }
+  }
+
+  // Returns 2D subrange whose top-left is `r, c` and width is `cols`.
+  template <typename T>
+  static StridedView<T> View(const MatPtrT<T>& AB, size_t r, size_t c,
+                             size_t cols) {
+    HWY_DASSERT(c < AB.Cols());
+    HWY_DASSERT(cols <= AB.Cols() - c);
+    return StridedView<T>(const_cast<T*>(AB.Row(r)) + c, cols, AB.Stride());
+  }
+
+  template <typename TA>
+  static HWY_INLINE StridedViewBF MaybeDecompressA(const MatPtrT<TA>& A,
+                                                   const MMArgs& args) {
     if constexpr (IsBF16<TA>()) {
       // We can use a view, regardless of columns/padding, because `LoopKC`
       // supports non-vector multiples.
-      DispatchOrder(parallel_policy, View(A, 0, 0, A.Cols()), B, C_rows);
+      return View(A, 0, 0, A.Cols());
     } else {
       // Always decompress. To reduce code size/compile time, we no longer
       // support a separate F32 kernel; most A are already BF16.
       const StridedViewBF A_view =
-          args_.env->storage[cluster_idx_].A(pkg_idx_, A.Extents());
-      DecompressA<MMParallelPolicyT>(A, A_view);
-      DispatchOrder(parallel_policy, A_view, B, C_rows);
+          args.env->storage[args.options.cluster_idx].A(/*pkg_idx=*/0,
+                                                        A.Extents());
+      DecompressA(A, A_view, args);
+      return A_view;
     }
   }
+};
 
+// Contains several variants of the outer M/N/K loops, and calls `A2C0` which
+// loops over the inner KC and MC. Member variables avoid long argument lists.
+class MMState {
+ public:
+  MMState(const Extents2D A, const MMArgs& args, const MMConfig& config)
+      : args_(args),
+        range_np_(args.per_key->ranges_np.Range(/*pkg_idx=*/0)),
+        mr_(config.MR()),
+        ranges_mc_(config.RangesOfMC(A.rows)),
+        ranges_kc_(config.RangesOfKC(A.cols)),
+        ranges_nc_(config.RangesOfNC(range_np_)),
+        order_(config.Order()),
+        inner_tasks_(config.InnerTasks()) {
+    HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1);
+  }
+
+  // Called from `MatMul` from two places: either with the next autotune config,
+  // or with the best config.
+  template <typename TB, typename TC>
+  HWY_NOINLINE void DispatchParallelism(const StridedViewBF A,
+                                        const MatPtrT<TB>& B,
+                                        RowPtrs<TC> C_rows) const {
+    /* Disabled due to unknown thread-safety issue:
+    static const auto zone =
+        args_.env->ctx.profiler.AddZone("MM.DispatchParallelism");
+    PROFILER_ZONE3(args_.env->ctx.profiler, MMImpl::Worker(args_), zone);
+    */
+
+    MMImpl::DispatchParallelism(
+        args_.options.parallelism,
+        [&](const auto& parallel) { DispatchOrder(parallel, A, B, C_rows); });
+  }
+
  private:
   // Compute size of per-worker storage for `kNR` row ranges of B. Stack
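`MMImpl::DispatchParallelism` above turns a runtime `ParallelismStrategy` value into a call that receives a concrete policy object, so the loop bodies are instantiated once per policy type. A cut-down sketch of that dispatch pattern follows; the policy structs here are empty placeholders (the real ones implement `ForNP`, `ForRangesMC_NC`, etc.), so treat this as an illustration rather than the repository's code.

```cpp
#include <cstdio>

enum class ParallelismStrategy { kNone, kWithinCluster, kHierarchical };

// Placeholder policies; the real ones carry the parallel-loop implementations.
struct ParallelNone { const char* name() const { return "none"; } };
struct ParallelWithinCluster { const char* name() const { return "cluster"; } };
struct ParallelHierarchical { const char* name() const { return "hierarchical"; } };

// Runtime enum -> compile-time policy: func is a generic lambda instantiated
// once per policy type.
template <class Func>
void Dispatch(ParallelismStrategy strategy, const Func& func) {
  switch (strategy) {
    case ParallelismStrategy::kNone: return func(ParallelNone());
    case ParallelismStrategy::kWithinCluster: return func(ParallelWithinCluster());
    case ParallelismStrategy::kHierarchical: return func(ParallelHierarchical());
  }
}

int main() {
  Dispatch(ParallelismStrategy::kWithinCluster, [](const auto& policy) {
    std::printf("dispatched to %s policy\n", policy.name());
  });
  return 0;
}
```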
@@ -616,40 +774,32 @@ class MMPerPackage {
   // Granularity of `ForNP`. B rows produce C columns, so we
   // want a multiple of the line size to prevent false sharing.
   size_t MultipleNP(size_t sizeof_TC) const {
-    return HWY_MAX(kNR, line_bytes_ / sizeof_TC);
+    return HWY_MAX(kNR, args_.line_bytes / sizeof_TC);
   }
 
-  // Returns 2D subrange whose top-left is `r, c` and width is `cols`.
-  template <typename T>
-  static StridedView<T> View(const MatPtrT<T>& AB, size_t r, size_t c,
-                             size_t cols) {
-    HWY_DASSERT(c < AB.Cols());
-    HWY_DASSERT(cols <= AB.Cols() - c);
-    return StridedView<T>(const_cast<T*>(AB.Row(r)) + c, cols, AB.Stride());
-  }
-
-  // `TA` is usually BF16, but can be F32 if `!HWY_NATIVE_DOT_BF16`.
-  template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
-  HWY_INLINE void DispatchOrder(const MMParallelPolicyT& parallel_policy,
-                                const StridedView<TA> A, const MatPtrT<TB>& B,
-                                RowPtrs<TC> C_rows) const {
+  // B is decompressed several call layers lower, but not all member functions
+  // depend on `TB`, so pass it as an argument instead of templating the class.
+  template <typename TB, typename TC, class ParallelT>
+  HWY_NOINLINE void DispatchOrder(const ParallelT& parallel_policy,
+                                  const StridedViewBF A, const MatPtrT<TB>& B,
+                                  RowPtrs<TC> C_rows) const {
     switch (order_) {
       case MMOrder::kNT:
-        return DoNT<TA, TB, TC>(parallel_policy, A, B, C_rows);
+        return DoNT(parallel_policy, A, B, C_rows);
       case MMOrder::kNT_K:
-        return DoNT_K<TA, TB, TC>(parallel_policy, A, B, C_rows);
+        return DoNT_K(parallel_policy, A, B, C_rows);
      case MMOrder::kNT_MT:
-        return DoNT_MT<TA, TB, TC>(parallel_policy, A, B, C_rows);
+        return DoNT_MT(parallel_policy, A, B, C_rows);
      case MMOrder::kNT_MT_K:
-        return DoNT_MT_K<TA, TB, TC>(parallel_policy, A, B, C_rows);
+        return DoNT_MT_K(parallel_policy, A, B, C_rows);
      default:
        HWY_UNREACHABLE;
    }
  }
 
   // Single M and K ranges, parallel N. Fills all of C directly.
-  template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
-  HWY_INLINE void DoNT(MMParallelPolicyT, const StridedView<TA> A,
+  template <typename TB, typename TC, class ParallelT>
+  HWY_INLINE void DoNT(ParallelT parallel, const StridedViewBF A,
                        const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
     static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT");
     HWY_DASSERT(ranges_mc_.NumTasks() == 1);
@@ -657,14 +807,14 @@ class MMPerPackage {
     const IndexRange& range_M = ranges_mc_.Range(0);
     const IndexRange& range_K = ranges_kc_.Range(0);
     const size_t K = range_K.Num();
-    const StridedView<TA> A_view = A.View(range_M.begin(), 0, K);
+    const StridedViewBF A_view = A.View(range_M.begin(), 0, K);
     const size_t B_stride =
-        Stride(MatPadding::kOdd, K, sizeof(BF16), line_bytes_);
+        Stride(MatPadding::kOdd, K, sizeof(BF16), args_.line_bytes);
 
     // Similar to `loop_nc` below, but here we hoisted `A_view`.
-    MMParallelPolicyT::ForNP(
+    parallel.ForNP(
         args_.env->ctx, range_np_, MultipleNP(sizeof(TC)), inner_tasks_,
-        pkg_idx_, cluster_idx_,
+        args_.options.cluster_idx,
         [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
           MMZone mm_zone;
           mm_zone.MaybeEnter(worker, zone, args_);
@@ -683,8 +833,8 @@ class MMPerPackage {
   }
 
   // Single M range, parallel N, sequential K. Sets C, then accumulates.
-  template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
-  HWY_INLINE void DoNT_K(MMParallelPolicyT, const StridedView<TA> A,
+  template <typename TB, typename TC, class ParallelT>
+  HWY_INLINE void DoNT_K(ParallelT parallel, const StridedViewBF A,
                          const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
     static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_K");
     HWY_DASSERT(ranges_mc_.NumTasks() == 1);
@@ -697,11 +847,11 @@ class MMPerPackage {
                              const IndexRange& range_nc,
                              auto out_tag) HWY_ATTR {
       const size_t kc = range_kc.Num();
-      const StridedView<TA> A_view =
+      const StridedViewBF A_view =
           A.View(range_mc.begin(), range_kc.begin(), kc);
       const StridedViewBF B_storage_view(
           B_storage, kc,
-          Stride(MatPadding::kOdd, kc, sizeof(BF16), line_bytes_));
+          Stride(MatPadding::kOdd, kc, sizeof(BF16), args_.line_bytes));
 
       for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
            row_b += kNR) {
@@ -711,9 +861,9 @@ class MMPerPackage {
       }
     };
 
-    MMParallelPolicyT::ForNP(
+    parallel.ForNP(
         args_.env->ctx, range_np_, MultipleNP(sizeof(TC)), inner_tasks_,
-        pkg_idx_, cluster_idx_,
+        args_.options.cluster_idx,
         [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
           MMZone mm_zone;
           mm_zone.MaybeEnter(worker, zone, args_);
@@ -733,26 +883,26 @@ class MMPerPackage {
 
   // Parallel loops over mc/nc blocks of M/range_np, single K.
   // Fills `mc x nc` sections of C directly, in parallel.
-  template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
-  HWY_INLINE void DoNT_MT(MMParallelPolicyT, const StridedView<TA> A,
+  template <typename TB, typename TC, class ParallelT>
+  HWY_INLINE void DoNT_MT(ParallelT parallel, const StridedViewBF A,
                           const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
     static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT");
     HWY_DASSERT(ranges_kc_.NumTasks() == 1);
     const IndexRange& range_K = ranges_kc_.Range(0);
     const size_t K = range_K.Num();
     const size_t B_stride =
-        Stride(MatPadding::kOdd, K, sizeof(BF16), line_bytes_);
+        Stride(MatPadding::kOdd, K, sizeof(BF16), args_.line_bytes);
 
     // Sequential loop over NC/MC/KC, similar to `loop_nc` below
     // except for the profiler strings and `out_tag`.
-    MMParallelPolicyT::ForRangesMC_NC(
-        args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_, cluster_idx_,
+    parallel.ForRangesMC_NC(
+        args_.env->ctx, ranges_mc_, ranges_nc_, args_.options.cluster_idx,
         [&](const IndexRange& range_mc, const IndexRange& range_nc,
             size_t worker) HWY_ATTR {
           MMZone mm_zone;
           mm_zone.MaybeEnter(worker, zone, args_);
 
-          const StridedView<TA> A_view = A.View(range_mc.begin(), 0, K);
+          const StridedViewBF A_view = A.View(range_mc.begin(), 0, K);
           HWY_ALIGN BF16 B_storage[B_storage_max_];  // TLS
           const StridedViewBF B_storage_view(B_storage, K, B_stride);
 
@@ -768,14 +918,14 @@ class MMPerPackage {
 
   // Parallel loops over mc/nc blocks of M/range_np, sequential K.
   // Fills `mc x nc` sections of `partial`, then `C`, in parallel.
-  template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
-  HWY_INLINE void DoNT_MT_K(MMParallelPolicyT, const StridedView<TA> A,
+  template <typename TB, typename TC, class ParallelT>
+  HWY_INLINE void DoNT_MT_K(ParallelT parallel, const StridedViewBF A,
                             const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
     static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT_K");
     const size_t kc_max = ranges_kc_.TaskSize();
     HWY_DASSERT(kc_max <= MMStorage::kMaxKC);
     const size_t B_stride =
-        Stride(MatPadding::kOdd, kc_max, sizeof(BF16), line_bytes_);
+        Stride(MatPadding::kOdd, kc_max, sizeof(BF16), args_.line_bytes);
     // Sequential loop over NC/MC/KC, for when the M/N loops are
     // already parallel. This is B3A2C0 in MOMMS terminology: we read
     // `mc x kc` of A, `nc x kc` of B, update `mc x nc` of `partial`.
@@ -785,7 +935,7 @@ class MMPerPackage {
                              const IndexRange& range_nc,
                              auto out_tag) HWY_ATTR {
       const size_t kc = range_kc.Num();
-      const StridedView<TA> A_view =
+      const StridedViewBF A_view =
           A.View(range_mc.begin(), range_kc.begin(), kc);
 
       for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
@@ -795,8 +945,8 @@ class MMPerPackage {
                   C_rows);
       }
     };  // loop_nc
-    MMParallelPolicyT::ForRangesMC_NC(
-        args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_, cluster_idx_,
+    parallel.ForRangesMC_NC(
+        args_.env->ctx, ranges_mc_, ranges_nc_, args_.options.cluster_idx,
         [&](const IndexRange& range_mc, const IndexRange& range_nc,
             size_t worker) HWY_ATTR {
           MMZone mm_zone;
@@ -816,106 +966,6 @@ class MMPerPackage {
         });
   }
 
-  // Decompresses all `M x K` from `A` into padded BF16 `A_view`.
-  template <typename MMParallelPolicyT>
-  HWY_NOINLINE void DoDecompressA(const MatPtrT<float>& A,
-                                  const StridedViewBF A_view,
-                                  MMParA par_a) const {
-    const IndexRange all_M(0, A.Rows());
-    const IndexRange all_K(0, A.Cols());
-    HWY_DASSERT(all_K.Num() == A_view.Cols());
-
-    const hn::ScalableTag<BF16> dbf;
-    const size_t NBF = hn::Lanes(dbf);
-
-    static const auto zone = args_.env->ctx.profiler.AddZone("MM.DecompressA");
-
-    const auto do_range = [&](const IndexRange& range_M,
-                              const IndexRange& range_K,
-                              size_t worker) HWY_ATTR {
-      MMZone mm_zone;
-      mm_zone.MaybeEnter(worker, zone, args_);
-
-      const size_t col0 = range_K.begin();
-      const size_t cols = range_K.Num();
-      // Must be a vector multiple, or the last range before row padding,
-      // otherwise `DecompressAndZeroPad` overwrites neighbors.
-      HWY_DASSERT(cols % NBF == 0 || range_K.end() == A.Cols());
-      for (size_t row_a : range_M) {
-        const PackedSpan<const float> from =
-            MakeSpan(A.Row(row_a) + col0, cols);
-        BF16* HWY_RESTRICT to = A_view.Row(row_a) + col0;
-        DecompressAndZeroPad(dbf, from, 0, to, cols);
-        // Verify that we zero-padded.
-        if constexpr (HWY_IS_DEBUG_BUILD) {
-          for (size_t i = cols; i < hwy::RoundUpTo(cols, NBF); ++i) {
-            HWY_DASSERT(hwy::ConvertScalarTo<float>(to[i]) == 0.0f);
-          }
-        }
-      }
-    };
-
-    switch (par_a) {
-      case MMParA::kNone:
-        do_range(all_M, all_K, /*worker=*/0);
-        break;
-      case MMParA::kK1:
-      case MMParA::kK2:
-      case MMParA::kK4: {
-        const size_t inner_tasks = static_cast<size_t>(par_a);
-        // At least one vector, otherwise DecompressAndZeroPad will add
-        // padding, which might overwrite neighboring tasks. Also a whole cache
-        // line to avoid false sharing.
-        const size_t multiple_K = HWY_MAX(NBF, line_bytes_ / sizeof(BF16));
-
-        MMParallelPolicyT::ForNP(args_.env->ctx, all_K, multiple_K, inner_tasks,
-                                 pkg_idx_, cluster_idx_,
-                                 [&](const IndexRange& range_K, size_t worker) {
-                                   do_range(all_M, range_K, worker);
-                                 });
-        break;
-      }
-      case MMParA::kM:
-        MMParallelPolicyT::ForRangeMC(
-            args_.env->ctx, all_M, pkg_idx_, cluster_idx_,
-            [&](size_t row_a, size_t worker) {
-              do_range(IndexRange(row_a, row_a + 1), all_K, worker);
-            });
-        break;
-    }
-  }
-
-  // Autotuning wrapper for `DoDecompressA`.
-  template <typename MMParallelPolicyT>
-  HWY_INLINE void DecompressA(const MatPtrT<float>& A,
-                              const StridedViewBF A_view) const {
-    MMAutoTune<MMParA>& autotune = args_.per_key->autotune_par_a[pkg_idx_];
-
-    if (HWY_LIKELY(autotune.Best())) {
-      return DoDecompressA<MMParallelPolicyT>(A, A_view, *autotune.Best());
-    }
-
-    // First call: generate candidates.
-    if (HWY_UNLIKELY(!autotune.HasCandidates())) {
-      const MMParA other = (A.Rows() == 1) ? MMParA::kNone : MMParA::kM;
-      std::vector<MMParA> candidates = {MMParA::kK1, MMParA::kK2, MMParA::kK4,
-                                        other};
-      autotune.SetCandidates(candidates);
-    }
-
-    const MMParA& par_a = autotune.NextConfig();
-    const uint64_t t0 = hwy::timer::Start();
-    DoDecompressA<MMParallelPolicyT>(A, A_view, par_a);
-    const uint64_t t1 =
-        args_.env->have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
-    const uint64_t min_elapsed = autotune.NotifyTicks(t1 - t0);
-    if (HWY_UNLIKELY(args_.env->print_measurement && autotune.ShouldPrint())) {
-      fprintf(stderr, "%s,%7.3f\n", StringFromParA(par_a),
-              static_cast<double>(min_elapsed) /
-                  hwy::platform::InvariantTicksPerSecond() * 1E6);
-    }
-  }
-
   // Decompresses `kNR x kc` from `B[row_b, range_kc.begin()]` to row 0,
   // col 0 of `B_view`. Decompressing SFP is relatively cheap on `AVX3_DL`
   // thanks to its large table lookups, and less so on other targets.
@@ -928,7 +978,7 @@ class MMPerPackage {
 
     // Neither A nor B require padding because `LoopKC` handles remainders.
     if constexpr (hwy::IsSame<TB, BF16>()) {
-      return View(B, row_b, range_kc.begin(), range_kc.Num());
+      return MMImpl::View(B, row_b, range_kc.begin(), range_kc.Num());
     }
 
     const PackedSpan<const TB> B_span = B.PaddedSpan();
@@ -951,8 +1001,6 @@ class MMPerPackage {
   }
 
   const MMArgs args_;  // copy for locality
-  const size_t pkg_idx_;
-  const size_t cluster_idx_;  // 0 for sequential and nested.
 
   const IndexRange range_np_;
   // From MMConfig:
@@ -962,52 +1010,7 @@ class MMPerPackage {
   const IndexRangePartition ranges_nc_;
   const MMOrder order_;
   const size_t inner_tasks_;
-  const size_t line_bytes_;
-};  // MMPerPackage
-
-// Stateless, wraps member functions.
-struct MMImpl {
-  // Returns existing entry for the given key or -1.
-  static HWY_INLINE intptr_t IndexOfKey(MMKeys::Key key, const MMKeys& keys) {
-    const hwy::Span<const uint64_t> all_keys = keys.Keys();
-    // TODO: SIMD scan
-    for (size_t i = 0; i < all_keys.size(); ++i) {
-      if (all_keys[i] == key) return static_cast<intptr_t>(i);
-    }
-    return -1;
-  }
-
-  // Called from `MatMul` from two places: either with the next autotune config,
-  // or with the best config.
-  template <typename TA, typename TB, typename TC>
-  static HWY_NOINLINE void DoMatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
-                                    RowPtrs<TC> C_rows, const MMArgs& args,
-                                    const MMConfig& config, MMOptions options) {
-    PROFILER_ZONE("MM.DoMatMul");
-    const size_t pkg_idx = 0;
-    HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1);
-    const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
-
-    switch (options.parallelism_type) {
-      case ParallelismType::kNested:
-        HWY_DASSERT(options.cluster_idx == 0);
-        MMPerPackage(A.Extents(), args, config, pkg_idx, options.cluster_idx,
-                     range_np)(MMNestedParallelPolicy(), A, B, C_rows);
-        break;
-      case ParallelismType::kSequential:
-        MMPerPackage(A.Extents(), args, config, pkg_idx, options.cluster_idx,
-                     range_np)(MMSequentialPolicy(), A, B, C_rows);
-      case ParallelismType::kWithinCluster:
-        MMPerPackage(A.Extents(), args, config, pkg_idx, options.cluster_idx,
-                     range_np)(MMClusterParallelPolicy(), A, B, C_rows);
-        break;
-      default:
-        HWY_ABORT("Parallelism type %s not implemented.",
-                  static_cast<int>(options.parallelism_type));
-        break;
-    }
-  }
-};
+};  // MMState
 
 // Computes the matrix product `A * B * scale [+ add]` and stores it in `C`.
 //
@@ -1033,17 +1036,19 @@ template <typename TA, typename TB, typename TC>
 HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                               const float* HWY_RESTRICT add, MatMulEnv& env,
                               MatPtrT<TC>& C, MMOptions options = MMOptions()) {
-  RowPtrs<TC> C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[0]);
+  HWY_DASSERT(options.cluster_idx < env.row_ptrs.size());
+  RowPtrs<TC> C_rows =
+      GetOrSetTempRowPtrs(C, env.row_ptrs[options.cluster_idx]);
 
   const Allocator& allocator = env.ctx.allocator;
   const size_t M = A.Rows();
   const size_t K = A.Cols();
   const size_t N = B.Rows();
   const MMKeys::Key key = MMKeys::KeyFromDims(M, K, N);
-  intptr_t index = MMImpl::IndexOfKey(key, env.keys);
+  intptr_t index = MMImpl::IndexOfKey(key, env.keys[options.cluster_idx]);
   // First time we see this shape/key.
   if (HWY_UNLIKELY(index < 0)) {
-    env.keys.Append(key, allocator);
+    env.keys[options.cluster_idx].Append(key, allocator);
 
     size_t max_packages = kMaxPackages;
     // For low-batch, multiple sockets only help if binding is enabled.
@@ -1052,16 +1057,19 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
     }
 
     // invalidates `MMAutoTune::Best()`
-    index = env.per_key.size();
-    env.per_key.push_back(MMPerKey(env.ctx, max_packages, N, sizeof(TC), kNR));
+    std::vector<MMPerKey>& stored_keys = env.per_key[options.cluster_idx];
+    index = stored_keys.size();
+    stored_keys.push_back(MMPerKey(env.ctx, max_packages, N, sizeof(TC), kNR));
   }
-  MMPerKey& per_key = env.per_key[index];
+  MMPerKey& per_key = env.per_key[options.cluster_idx][index];
   MMAutoTune<MMConfig>& tuner = per_key.autotune;
 
   const MMArgs args(env, per_key, static_cast<double>(A.Scale()) * B.Scale(),
-                    add);
+                    add, options);
   if (HWY_LIKELY(tuner.Best())) {
-    MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best(), options);
+    const MMState state(A.Extents(), args, *tuner.Best());
+    const StridedViewBF A_view = MMImpl::MaybeDecompressA(A, args);
+    state.DispatchParallelism(A_view, B, C_rows);
     return &per_key;
   }
 
@@ -1089,7 +1097,9 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
 
   const MMConfig& cfg = tuner.NextConfig();
   const uint64_t t0 = hwy::timer::Start();
-  MMImpl::DoMatMul(A, B, C_rows, args, cfg, options);
+  MMState state(A.Extents(), args, cfg);
+  const StridedViewBF A_view = MMImpl::MaybeDecompressA(A, args);
+  state.DispatchParallelism(A_view, B, C_rows);
   const uint64_t t1 =
       env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
   const double min_elapsed = static_cast<double>(tuner.NotifyTicks(t1 - t0)) /
@@ -405,20 +405,15 @@ IndexRangePartition MMRangesOfNP(ThreadingContext& ctx, size_t max_packages,
 MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx) {
   // Create storage per cluster. This only applies to in-cluster parallelism.
   // For nested and sequential parallelism, a single MMStorage is used.
-  size_t num_packages = ctx.topology.NumPackages();
-  size_t num_clusters = 0;
-  for (size_t pkg_idx = 0; pkg_idx < num_packages; ++pkg_idx) {
-    num_clusters += ctx.topology.NumClusters(pkg_idx);
-  }
+  const size_t num_clusters = ctx.pools.AllClusters(/*pkg_idx=*/0).NumWorkers();
   storage.reserve(num_clusters);
   for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) {
     storage.push_back(MMStorage(ctx));
+    row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(kMaxBatchSize));  // C
   }
 
   char cpu100[100];
   have_timer_stop = hwy::platform::HaveTimerStop(cpu100);
-
-  row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(kMaxBatchSize));  // C
 }
 
 void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC) {
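With this change, `MatMulEnv` sizes its per-cluster state (storage and row pointers, and, per the `MatMul` changes above, keys and per-key autotune entries) by the cluster count reported by the pools, so each cluster gets its own slot. The sketch below illustrates that "one slot per cluster" layout with simplified, made-up member types; it is not the repository's `MatMulEnv`.

```cpp
#include <cstddef>
#include <vector>

// Simplified stand-ins for MMStorage / row-pointer buffers.
struct StorageSketch { std::vector<float> scratch; };

struct MatMulEnvSketch {
  explicit MatMulEnvSketch(size_t num_clusters) {
    storage.reserve(num_clusters);
    row_ptrs.reserve(num_clusters);
    for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) {
      storage.push_back(StorageSketch{});
      row_ptrs.push_back(std::vector<void*>(16));  // one buffer per cluster
    }
  }
  // Indexed by the caller's cluster_idx, so concurrent per-cluster MatMuls
  // never share mutable state.
  std::vector<StorageSketch> storage;
  std::vector<std::vector<void*>> row_ptrs;
};

int main() {
  MatMulEnvSketch env(/*num_clusters=*/2);
  return 0;
}
```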
ops/matmul.h (235 changed lines)
@@ -58,147 +58,127 @@ IndexRangePartition MMRangesOfNP(ThreadingContext& ctx, size_t max_packages,
                                  size_t N, size_t sizeof_TC, size_t nr);
 
 struct MMOptions {
-  ParallelismType parallelism_type = ParallelismType::kNested;
-  uint8_t cluster_idx = 0;
+  uint32_t cluster_idx = 0;  // for `parallelism == kWithinCluster`.
+  ParallelismStrategy parallelism = ParallelismStrategy::kHierarchical;
 };
 
-struct MMSequentialPolicy {
-  template <class Func>
-  static void ForPkg(ThreadingContext& ctx, const size_t max_packages,
-                     const Func& func) {
-    func(/*pkg_idx=*/0);
-  }
-
+// Policy classes for parallelism, implementing some of `ParallelismStrategy`.
+struct MMParallelNone {
   template <class Func>
-  static void ForNP(ThreadingContext& ctx, const IndexRange& range_np,
-                    size_t nx_multiple, size_t inner_tasks, size_t pkg_idx,
-                    size_t cluster_idx, const Func& func) {
+  void ForNP(ThreadingContext& ctx, const IndexRange& range_np,
+             size_t nx_multiple, size_t inner_tasks, size_t cluster_idx,
+             const Func& func) const {
     HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
-    const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage() +
-                            cluster_idx * ctx.pools.MaxWorkersPerCluster();
-    func(range_np, base_idx);
+    const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster();
+    func(range_np, worker);
   }
 
   template <class Func>
-  static void ForRangesMC_NC(ThreadingContext& ctx,
-                             const IndexRangePartition& ranges_mc,
-                             const IndexRangePartition& ranges_nc,
-                             size_t pkg_idx, size_t cluster_idx,
-                             const Func& func) {
-    const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage() +
-                            cluster_idx * ctx.pools.MaxWorkersPerCluster();
+  void ForRangesMC_NC(ThreadingContext& ctx,
+                      const IndexRangePartition& ranges_mc,
+                      const IndexRangePartition& ranges_nc, size_t cluster_idx,
+                      const Func& func) const {
+    const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster();
     for (size_t i = 0; i < ranges_mc.NumTasks(); ++i) {
       const IndexRange range_mc = ranges_mc.Range(i);
       for (size_t j = 0; j < ranges_nc.NumTasks(); ++j) {
         const IndexRange range_nc = ranges_nc.Range(j);
-        func(range_mc, range_nc, base_idx);
+        func(range_mc, range_nc, worker);
       }
     }
   }
 
   template <class Func>
-  static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
-                         size_t pkg_idx, size_t cluster_idx, const Func& func) {
-    const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage() +
-                            cluster_idx * ctx.pools.MaxWorkersPerCluster();
+  void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
+                  size_t cluster_idx, const Func& func) const {
+    const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster();
     for (uint64_t row_a = range_mc.begin(); row_a < range_mc.end(); ++row_a) {
-      func(row_a, base_idx);
+      func(row_a, worker);
     }
   }
 };
 
-struct MMClusterParallelPolicy {
+struct MMParallelWithinCluster {
   template <class Func>
-  static void ForPkg(ThreadingContext& ctx, const size_t max_packages,
-                     const Func& func) {
-    func(/*pkg_idx=*/0);
-  }
-
-  template <class Func>
-  static void ForNP(ThreadingContext& ctx, const IndexRange& range_np,
-                    size_t nx_multiple, size_t inner_tasks, size_t pkg_idx,
-                    size_t cluster_idx, const Func& func) {
+  void ForNP(ThreadingContext& ctx, const IndexRange& range_np,
+             size_t nx_multiple, size_t inner_tasks, size_t cluster_idx,
+             const Func& func) const {
     HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
 
+    const size_t pkg_idx = 0;
     hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
+    const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster();
 
     const IndexRangePartition worker_ranges = StaticPartition(
         range_np, cluster.NumWorkers() * inner_tasks, nx_multiple);
     ParallelizeOneRange(worker_ranges, cluster,
                         [&](const IndexRange& worker_range, size_t worker) {
-                          func(worker_range, worker);
+                          func(worker_range, base + worker);
                         });
   }
 
   template <class Func>
-  static void ForRangesMC_NC(ThreadingContext& ctx,
-                             const IndexRangePartition& ranges_mc,
-                             const IndexRangePartition& ranges_nc,
-                             size_t pkg_idx, size_t cluster_idx,
-                             const Func& func) {
+  void ForRangesMC_NC(ThreadingContext& ctx,
+                      const IndexRangePartition& ranges_mc,
+                      const IndexRangePartition& ranges_nc, size_t cluster_idx,
+                      const Func& func) const {
+    const size_t pkg_idx = 0;
     hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
+    const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster();
 
     // Low-batch: avoid Divide/Remainder.
     if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
       ParallelizeOneRange(ranges_nc, cluster,
-                          [&](const IndexRange& range_nc, size_t thread) {
-                            func(ranges_mc.Range(0), range_nc, thread);
+                          [&](const IndexRange& range_nc, size_t worker) {
+                            func(ranges_mc.Range(0), range_nc, base + worker);
                           });
     } else {
       ParallelizeTwoRanges(
           ranges_mc, ranges_nc, cluster,
           [&](const IndexRange& range_mc, const IndexRange& range_nc,
-              size_t thread) { func(range_mc, range_nc, thread); });
+              size_t worker) { func(range_mc, range_nc, base + worker); });
     }
   }
 
   template <class Func>
-  static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
-                         size_t pkg_idx, size_t cluster_idx, const Func& func) {
+  void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
+                  size_t cluster_idx, const Func& func) const {
+    const size_t pkg_idx = 0;
     hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
-    cluster.Run(range_mc.begin(), range_mc.end(),
-                [&](uint64_t row_a, size_t thread) { func(row_a, thread); });
+    const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster();
+    cluster.Run(
+        range_mc.begin(), range_mc.end(),
+        [&](uint64_t row_a, size_t worker) { func(row_a, base + worker); });
   }
 };
 
-struct MMNestedParallelPolicy {
-  template <class Func>
-  static void ForPkg(ThreadingContext& ctx, const size_t max_packages,
-                     const Func& func) {
-    if constexpr (kMaxPackages > 1) {
-      ctx.pools.AllPackages().Run(
-          0, HWY_MIN(max_packages, ctx.pools.NumPackages()),
-          [&](uint64_t task, size_t pkg_idx) {
-            HWY_DASSERT(task == pkg_idx);
-            (void)task;
-            func(pkg_idx);
-          });
-    } else {
-      func(/*pkg_idx=*/0);
-    }
-  }
-
+struct MMParallelHierarchical {
   // Cluster/CCX-aware parallel-for over B rows in `range_np`. `nx_multiple` is
   // the granularity of per-cluster tasks. Calls `func(worker_range, worker)`.
-  // `cluster_idx` is not used here as all clusters within a package are used.
   template <class Func>
-  static void ForNP(ThreadingContext& ctx, const IndexRange& range_np,
-                    size_t nx_multiple, size_t inner_tasks, size_t pkg_idx,
-                    size_t /*cluster_idx*/, const Func& func) {
+  void ForNP(ThreadingContext& ctx, const IndexRange& range_np,
+             size_t nx_multiple, size_t inner_tasks,
+             HWY_MAYBE_UNUSED size_t caller_cluster_idx,
+             const Func& func) const {
     HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
-    const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage();
+    HWY_DASSERT(caller_cluster_idx == 0);
 
     // Single cluster: parallel-for over static partition of `range_np`.
+    const size_t pkg_idx = 0;
     hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx);
     const size_t num_clusters = all_clusters.NumWorkers();
     if (num_clusters == 1) {
-      hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, 0);
+      const size_t cluster_idx = 0;
+      hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
       const IndexRangePartition worker_ranges = StaticPartition(
           range_np, cluster.NumWorkers() * inner_tasks, nx_multiple);
       return ParallelizeOneRange(
           worker_ranges, cluster,
-          [&](const IndexRange& worker_range, size_t thread) {
-            func(worker_range, pkg_base + thread);
+          [&](const IndexRange& worker_range, size_t worker) {
+            func(worker_range, worker);
           });
     }
 
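All three renamed policy classes derive the `worker` argument the same way: a per-cluster base of `cluster_idx * ctx.pools.MaxWorkersPerCluster()` plus the pool's within-cluster worker index, which is what makes the index unique across clusters (per the commit message). Below is a minimal standalone sketch of that arithmetic only; `kMaxWorkersPerCluster` and `kNumClusters` are illustrative assumptions, not values queried from `NestedPools`.

#include <cstddef>
#include <cstdio>

// Illustrative stand-ins; the real values come from ctx.pools.
constexpr size_t kMaxWorkersPerCluster = 8;
constexpr size_t kNumClusters = 4;

// Same formula as the policy classes above: per-cluster base + local worker.
size_t GlobalWorkerIndex(size_t cluster_idx, size_t worker_in_cluster) {
  return cluster_idx * kMaxWorkersPerCluster + worker_in_cluster;
}

int main() {
  // Each cluster owns the half-open range
  // [c * kMaxWorkersPerCluster, (c + 1) * kMaxWorkersPerCluster), so indices
  // from different clusters never collide.
  for (size_t c = 0; c < kNumClusters; ++c) {
    for (size_t w = 0; w < kMaxWorkersPerCluster; ++w) {
      std::printf("cluster %zu, local worker %zu -> global worker %zu\n", c, w,
                  GlobalWorkerIndex(c, w));
    }
  }
  return 0;
}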
@@ -210,28 +190,29 @@ struct MMNestedParallelPolicy {
         [&](const IndexRange& nx_range, const size_t cluster_idx) {
           hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
           const size_t cluster_base =
-              pkg_base + cluster_idx * ctx.pools.MaxWorkersPerCluster();
+              cluster_idx * ctx.pools.MaxWorkersPerCluster();
           // Parallel-for over sub-ranges of `cluster_range` within the cluster.
           const IndexRangePartition worker_ranges = StaticPartition(
               nx_range, cluster.NumWorkers() * inner_tasks, nx_multiple);
           ParallelizeOneRange(
               worker_ranges, cluster,
-              [&](const IndexRange& worker_range, size_t thread) {
-                func(worker_range, cluster_base + thread);
+              [&](const IndexRange& worker_range, size_t worker) {
+                func(worker_range, cluster_base + worker);
               });
         });
   }
 
   // Cluster/CCX-aware parallel-for over blocks (separate subranges of A and B
   // rows). Calls `func(range_mc, range_nc, worker)`.
-  // `cluster_idx` is not used here as all clusters within a package are used.
   template <class Func>
-  static void ForRangesMC_NC(ThreadingContext& ctx,
-                             const IndexRangePartition& ranges_mc,
-                             const IndexRangePartition& ranges_nc,
-                             size_t pkg_idx, size_t /*cluster_idx*/,
-                             const Func& func) {
-    const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage();
+  void ForRangesMC_NC(ThreadingContext& ctx,
+                      const IndexRangePartition& ranges_mc,
+                      const IndexRangePartition& ranges_nc,
+                      HWY_MAYBE_UNUSED size_t caller_cluster_idx,
+                      const Func& func) const {
+    const size_t pkg_idx = 0;
+    HWY_DASSERT(caller_cluster_idx == 0);
 
     hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx);
     // `all_clusters` is a pool with one worker per cluster in a package.
     const size_t num_clusters = all_clusters.NumWorkers();
 
@@ -243,16 +224,14 @@ struct MMNestedParallelPolicy {
       // Low-batch: avoid Divide/Remainder.
       if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
         return ParallelizeOneRange(
-            ranges_nc, cluster, [&](const IndexRange& range_nc, size_t thread) {
-              func(ranges_mc.Range(0), range_nc, pkg_base + thread);
+            ranges_nc, cluster, [&](const IndexRange& range_nc, size_t worker) {
+              func(ranges_mc.Range(0), range_nc, worker);
             });
       } else {
         return ParallelizeTwoRanges(
             ranges_mc, ranges_nc, cluster,
             [&](const IndexRange& range_mc, const IndexRange& range_nc,
-                size_t thread) {
-              func(range_mc, range_nc, pkg_base + thread);
-            });
+                size_t worker) { func(range_mc, range_nc, worker); });
       }
     }
 
@@ -262,25 +241,23 @@ struct MMNestedParallelPolicy {
         ranges_nc, all_clusters,
         [&](const IndexRange range_nc, size_t cluster_idx) {
           const size_t cluster_base =
-              pkg_base + cluster_idx * ctx.pools.MaxWorkersPerCluster();
+              cluster_idx * ctx.pools.MaxWorkersPerCluster();
           hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
           ParallelizeOneRange(ranges_mc, cluster,
-                              [&](const IndexRange& range_mc, size_t thread) {
-                                func(range_mc, range_nc, cluster_base + thread);
+                              [&](const IndexRange& range_mc, size_t worker) {
+                                func(range_mc, range_nc, cluster_base + worker);
                               });
         });
   }
 
   // Calls `func(row_a, worker)` in parallel.
-  // `cluster_idx` is not used here as all clusters within a package are used.
   template <class Func>
-  static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
-                         size_t pkg_idx, size_t /*cluster_idx*/,
-                         const Func& func) {
-    const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage();
-    ctx.pools.Pool(pkg_idx).Run(
-        range_mc.begin(), range_mc.end(),
-        [&](uint64_t row_a, size_t thread) { func(row_a, pkg_base + thread); });
+  void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
+                  size_t caller_cluster_idx, const Func& func) const {
+    HierarchicalParallelFor(range_mc.Num(), ctx.pools,
+                            [&](size_t task, size_t worker) {
+                              func(range_mc.begin() + task, worker);
+                            });
   }
 };
 
@@ -340,27 +317,22 @@ class MMStorage {
   // Internally threaded; must not be called concurrently with the same
   // `ThreadingContext` (used via `parallel`).
   MMStorage(ThreadingContext& ctx) {
-    // Per-package allocation so each can decompress A into its own copy.
-    // Must be padded, see `DoDecompressA`.
-    // Default to nested parallel policy.
-    MMNestedParallelPolicy::ForPkg(ctx, kMaxPackages, [&](size_t pkg_idx) {
-      Allocator& allocator = ctx.allocator;
-
-      // 0.5 GiB per package.
-      pkg_A_[pkg_idx].reset(
-          new MatStorageT<BF16>("pkg_A", Extents2D(kMaxBatchSize, kMaxK),
-                                allocator, MatPadding::kOdd));
-
-      if (allocator.ShouldBind()) {
-        const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node();
-        size_t bytes = pkg_A_[pkg_idx]->Rows() * pkg_A_[pkg_idx]->Stride() *
-                       pkg_A_[pkg_idx]->ElementBytes();
-        bytes = hwy::RoundDownTo(bytes, allocator.BasePageBytes());
-        if (!allocator.BindMemory(pkg_A_[pkg_idx]->Row(0), bytes, node)) {
-          HWY_WARN("Failed to bind memory for package %zu", pkg_idx);
-        }
-      }
-    });
+    Allocator& allocator = ctx.allocator;
+    const size_t pkg_idx = 0;
+
+    // 0.5 GiB per package. Must be padded, see `DoDecompressA`.
+    pkg_A_[pkg_idx].reset(new MatStorageT<BF16>(
+        "pkg_A", Extents2D(kMaxBatchSize, kMaxK), allocator, MatPadding::kOdd));
+
+    if (allocator.ShouldBind()) {
+      const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node();
+      size_t bytes = pkg_A_[pkg_idx]->Rows() * pkg_A_[pkg_idx]->Stride() *
+                     pkg_A_[pkg_idx]->ElementBytes();
+      bytes = hwy::RoundDownTo(bytes, allocator.BasePageBytes());
+      if (!allocator.BindMemory(pkg_A_[pkg_idx]->Row(0), bytes, node)) {
+        HWY_WARN("Failed to bind memory for package %zu", pkg_idx);
+      }
+    }
   }
 
   // Returns per-package matrix view. Converting A=F32 to BF16 up-front is
@@ -735,16 +707,18 @@ struct MatMulEnv {
   bool print_best = false;
 
   std::vector<MMStorage> storage;
-  MMKeys keys;
-  std::vector<MMPerKey> per_key;
+  MMKeys keys[kMaxClusters];
+  std::vector<MMPerKey> per_key[kMaxClusters];
 
   // Storage for arbitrary output rows, see `MatPtr::AllocateAndAttachRowPtrs`.
   // Most MatMul callers use strided MatPtr, but GemmaAttention::ComputeQKV
   // writes to differing KV positions per query / output row.
-  // The first entry is sufficient for any C argument, but also potentially
-  // overwritten by each MatMul. Subsequent entries are precomputed for tensors
-  // and not overwritten. Per-tensor allocations make it likelier that asan
-  // detects bugs such as use after free, overrun, and dangling references.
+  // The first `num_clusters` entries are sufficient for any C argument, and
+  // must be indexed by `options.cluster_idx`. Note that they are potentially
+  // overwritten by each `MatMul`. Subsequent entries are for specific tensors
+  // and only written once by their allocator. A per-tensor allocation makes it
+  // likelier that asan detects bugs such as use after free, overrun, and
+  // dangling references.
   std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
 };
 
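The per-cluster arrays allow concurrent MatMul calls, each pinned to a different cluster via `MMOptions::cluster_idx`, to autotune and to reuse output-row scratch without sharing mutable state. A hedged sketch of the indexing convention only, assuming ops/matmul.h as shown in this diff is included; the helper name is illustrative and not repository code.

// Illustrative helper, not repository code: selects the per-cluster slots.
void UsePerClusterState(MatMulEnv& env, const MMOptions& options) {
  // Autotuning keys and per-key state are kept per cluster.
  MMKeys& keys = env.keys[options.cluster_idx];
  std::vector<MMPerKey>& per_key = env.per_key[options.cluster_idx];

  // The first `num_clusters` entries of row_ptrs are per-cluster scratch for
  // the C argument and may be overwritten by every MatMul; later entries
  // belong to specific tensors and are written only once.
  uint8_t** c_rows = env.row_ptrs[options.cluster_idx].get();

  (void)keys;
  (void)per_key;
  (void)c_rows;
}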
@@ -752,14 +726,21 @@ struct MatMulEnv {
 // Reduces register pressure compared to individual values/references.
 struct MMArgs {
   MMArgs(MatMulEnv& env, MMPerKey& per_key, double scale,
-         const float* HWY_RESTRICT add)
-      : env(&env), per_key(&per_key), scale(scale), add(add) {}
+         const float* HWY_RESTRICT add, MMOptions options)
+      : env(&env),
+        per_key(&per_key),
+        scale(scale),
+        add(add),
+        options(options),
+        line_bytes(env.ctx.allocator.LineBytes()) {}
 
   MatMulEnv* env;
   MMPerKey* per_key;
 
   double scale;
   const float* HWY_RESTRICT add;
+  MMOptions options;
+  size_t line_bytes;
 };
 
 // Wrapper over hwy::Zone that is only enabled when autotuning finished.
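With `MMOptions` carried inside `MMArgs`, a caller picks the cluster and strategy once and the bundle travels through the kernels; `line_bytes` is captured at construction so inner loops do not chase `env.ctx.allocator` again. A hedged construction sketch, assuming `env`, `per_key`, `scale` and `add` already exist with the types shown above; the helper and the chosen values are illustrative, not repository code.

// Illustrative helper, not repository code.
MMArgs MakeArgs(MatMulEnv& env, MMPerKey& per_key, double scale,
                const float* HWY_RESTRICT add, size_t cluster_idx) {
  MMOptions options;
  options.parallelism = ParallelismStrategy::kWithinCluster;
  options.cluster_idx = static_cast<uint32_t>(cluster_idx);
  // line_bytes is captured once here, via env.ctx.allocator.LineBytes().
  return MMArgs(env, per_key, scale, add, options);
}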
@@ -501,12 +501,12 @@ void RMSNormBatched(const MatPtrT<XT>& activations, const MatPtr& weights,
   HWY_DASSERT(activations.SameShape(out));
 
   CallUpcasted(&weights, [&](const auto* weights_t) {
-    ParallelFor(
-        ParallelismType::kAcrossClusters, activations.Rows(), ctx.pools,
-        cluster_idx, [&](uint64_t token_idx, size_t worker) {
-          RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(),
-                  out.Row(token_idx), activations.Cols(), ctx.profiler, worker);
+    ParallelFor(ParallelismStrategy::kFlat, activations.Rows(), ctx,
+                cluster_idx, [&](uint64_t token_idx, size_t worker) {
+                  RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(),
+                          out.Row(token_idx), activations.Cols(), ctx.profiler,
+                          worker);
                 });
   });
 }
@@ -517,12 +517,12 @@ void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT<XT>& inout,
   HWY_DASSERT(weights.Cols() == inout.Cols());
 
   CallUpcasted(&weights, [&](const auto* weights_t) {
-    ParallelFor(
-        ParallelismType::kAcrossClusters, inout.Rows(), ctx.pools, cluster_idx,
-        [&](uint64_t token_idx, size_t worker) {
-          RMSNormInplace(weights_t->PackedScale1(), inout.Row(token_idx),
-                         inout.Cols(), ctx.profiler, worker);
+    ParallelFor(ParallelismStrategy::kFlat, inout.Rows(), ctx, cluster_idx,
+                [&](uint64_t token_idx, size_t worker) {
+                  RMSNormInplace(weights_t->PackedScale1(),
+                                 inout.Row(token_idx), inout.Cols(),
+                                 ctx.profiler, worker);
                 });
   });
 }
@@ -548,8 +548,8 @@ static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x, MatPtrT<float>& out,
                                       ThreadingContext& ctx,
                                       size_t cluster_idx = 0) {
   HWY_DASSERT(out.SameShape(x));
-  ParallelFor(ParallelismType::kAcrossClusters, out.Rows(), ctx.pools,
-              cluster_idx, [&](uint64_t token_idx, size_t worker) {
+  ParallelFor(ParallelismStrategy::kFlat, out.Rows(), ctx, cluster_idx,
+              [&](uint64_t token_idx, size_t worker) {
                 AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(),
                         ctx.profiler, worker);
               });
@@ -782,8 +782,8 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCapBatched(
     const float cap, MatPtrT<float>& x, const hwy::BitSet4096<>& non_eos,
     ThreadingContext& ctx, size_t cluster_idx = 0) {
   if (cap == 0.0f) return;
-  ParallelFor(ParallelismType::kAcrossClusters, x.Rows(), ctx.pools,
-              cluster_idx, [&](uint64_t task, size_t worker) {
+  ParallelFor(ParallelismStrategy::kFlat, x.Rows(), ctx, cluster_idx,
+              [&](uint64_t task, size_t worker) {
                 if (non_eos.Get(task)) {
                   LogitsSoftCap(cap, x.Row(task), x.Cols(), ctx.profiler,
                                 worker);
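The batched ops above all switch to the same pattern: `ParallelismStrategy::kFlat` over rows, forwarding the caller's `cluster_idx` so that calls made from within an across-cluster loop remain legal. A hedged sketch of that pattern for a hypothetical row-wise op; `ScaleRowsBatched` is illustrative and not a function in the repository, and it assumes the `ParallelFor` declared in threading_context.h further below.

// Illustrative only: same shape as AddFromBatched above.
static void ScaleRowsBatched(float scale, MatPtrT<float>& x,
                             ThreadingContext& ctx, size_t cluster_idx = 0) {
  ParallelFor(ParallelismStrategy::kFlat, x.Rows(), ctx, cluster_idx,
              [&](uint64_t row, size_t worker) {
                (void)worker;  // unique across clusters, e.g. for profiler TLS
                float* HWY_RESTRICT p = x.Row(row);
                for (size_t c = 0; c < x.Cols(); ++c) p[c] *= scale;
              });
}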
@@ -35,6 +35,8 @@ namespace gcpp {
 // is disabled if this is 1.
 HWY_INLINE_VAR constexpr size_t kMaxPackages = 1;
 
+HWY_INLINE_VAR constexpr size_t kMaxClusters = 128;  // TODO: shrink
+
 // TODO: extend to 16k after updating non_eos.
 HWY_INLINE_VAR constexpr size_t kMaxBatchSize = 4096;
@@ -326,7 +326,8 @@ void ParallelizeTwoRanges(const IndexRangePartition& get1,
 // Calls `func(task, worker)` for each task in `[0, num_tasks)`. Parallelizes
 // over clusters of ONE package, then within each cluster.
 template <class Func>
-void NestedParallelFor(size_t num_tasks, NestedPools& pools, const Func& func) {
+void HierarchicalParallelFor(size_t num_tasks, NestedPools& pools,
+                             const Func& func) {
   // Even if there are multiple packages, we only use the first.
   const size_t pkg_idx = 0;
 
@@ -356,59 +357,6 @@ void NestedParallelFor(size_t num_tasks, NestedPools& pools, const Func& func) {
   });
 }
 
-// Which pool(s) to use for parallelizing:
-enum class ParallelismType : uint8_t {
-  // None: single-threaded loop on the calling thread.
-  kSequential,
-  // One thread per cluster within the first package; or one per core if there
-  // is only one cluster. Use for few or lightweight tasks, or to maximize
-  // memory bandwidth availability.
-  kAcrossClusters,
-  // All cores within the cluster identified by `cluster_idx`. Use if already
-  // within a `kAcrossClusters` parallel-for, or if latency is more important
-  // than memory bandwidth.
-  kWithinCluster,
-  // First statically partitions `kAcrossClusters`, then `kWithinCluster`. This
-  // utilizes all cores, but has higher fork-join overhead (two barriers); use
-  // if there are many or heavy tasks.
-  kNested,
-};
-
-// Calls `func(task, worker)` for each `task` in `[0, num_tasks)`, with the
-// number/type of workers determined by `parallelism`. `cluster_idx` is only
-// used if `parallelism == kWithinCluster`.
-template <class Func>
-void ParallelFor(ParallelismType parallelism, size_t num_tasks,
-                 NestedPools& pools, size_t cluster_idx, const Func& func) {
-  if (cluster_idx != 0) {
-    // If already running across clusters, must not use across-cluster modes.
-    HWY_DASSERT(parallelism != ParallelismType::kAcrossClusters &&
-                parallelism != ParallelismType::kNested);
-  }
-
-  const size_t pkg_idx = 0;
-  switch (parallelism) {
-    case ParallelismType::kSequential:
-      for (size_t task = 0; task < num_tasks; ++task) {
-        func(task, /*worker=*/0);
-      }
-      return;
-
-    case ParallelismType::kAcrossClusters:
-      return pools.Pool(pkg_idx).Run(
-          0, num_tasks,
-          [&](uint64_t task, size_t worker) { func(task, worker); });
-
-    case ParallelismType::kWithinCluster:
-      return pools.Cluster(pkg_idx, cluster_idx)
-          .Run(0, num_tasks,
-               [&](uint64_t task, size_t worker) { func(task, worker); });
-
-    case ParallelismType::kNested:
-      return NestedParallelFor(num_tasks, pools, func);
-  }
-}
-
 }  // namespace gcpp
 
 #endif  // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_
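Only the name changes here (`NestedParallelFor` becomes `HierarchicalParallelFor`, matching the `kHierarchical` strategy); the contract stays the same: `func(task, worker)` over `[0, num_tasks)`, parallelized first across the clusters of one package and then within each cluster. A hedged call-site sketch under stated assumptions: `pools` is an already-constructed `NestedPools`, and sizing the per-worker array by `MaxWorkersPerPackage()` is an assumption based on the single-package use shown in this diff.

#include <cstdint>
#include <vector>

// Illustrative helper, not repository code: per-worker partial sums avoid
// contention; worker indices stay below the per-package maximum because only
// the first package is used.
uint64_t SumTasks(NestedPools& pools, size_t num_tasks) {
  std::vector<uint64_t> partial(pools.MaxWorkersPerPackage(), 0);
  HierarchicalParallelFor(num_tasks, pools, [&](size_t task, size_t worker) {
    partial[worker] += static_cast<uint64_t>(task);
  });
  uint64_t sum = 0;
  for (uint64_t p : partial) sum += p;
  return sum;
}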
@@ -116,6 +116,95 @@ struct ThreadingContext {
   NestedPools pools;
 };
 
+// Describes the strategy for distributing parallel work across cores.
+enum class ParallelismStrategy : uint8_t {
+  // Execute using a single-threaded loop on the calling thread. The `worker`
+  // index passed to the user's `Func` is unique across clusters.
+  kNone,
+  // One thread per cluster within the first package. The `worker` index passed
+  // to the user's `Func` is a `cluster_idx <= NumClusters()`. Some CPUs may
+  // only have a single cluster, hence `Func` should also contain a nested
+  // `ParallelFor` with `kWithinCluster`.
+  kAcrossClusters,
+  // All cores within the cluster identified by `cluster_idx`. The `worker`
+  // index passed to the user's `Func` is unique across clusters. Choose this
+  // strategy if already within a `ParallelFor` call with `kAcrossClusters`,
+  // or latency is more important than memory bandwidth.
+  kWithinCluster,
+  // Equivalent to `kAcrossClusters` if there are multiple clusters, otherwise
+  // `kWithinCluster`. Use for few or lightweight tasks (this only uses a
+  // single pool and barrier), or to maximize memory bandwidth availability.
+  kFlat,
+  // First statically partitions `kAcrossClusters`, then `kWithinCluster`. This
+  // utilizes all cores, but has higher fork-join overhead (two barriers); use
+  // if there are many or heavy tasks.
+  kHierarchical,
+};
+
+// Calls `func(task, worker)` for each `task` in `[0, num_tasks)`, with the
+// number/type of workers determined by `parallelism`. `cluster_idx` is for
+// `parallelism == kWithinCluster`, and should be 0 if unknown.
+template <class Func>
+void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
+                 ThreadingContext& ctx, size_t cluster_idx, const Func& func) {
+  HWY_DASSERT(ctx.topology.NumPackages() == 1);
+  const size_t pkg_idx = 0;
+
+  HWY_DASSERT(cluster_idx < ctx.topology.NumClusters(pkg_idx));
+  if (cluster_idx != 0) {
+    // If already running across clusters, only use within-cluster modes.
+    HWY_DASSERT(parallelism == ParallelismStrategy::kNone ||
+                parallelism == ParallelismStrategy::kWithinCluster);
+  }
+
+  switch (parallelism) {
+    case ParallelismStrategy::kNone: {
+      const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster();
+      for (size_t task = 0; task < num_tasks; ++task) {
+        func(task, worker);
+      }
+      return;
+    }
+
+    case ParallelismStrategy::kAcrossClusters:
+      return ctx.pools.AllClusters(pkg_idx).Run(
+          0, num_tasks,
+          [&](uint64_t task, size_t cluster_idx) { func(task, cluster_idx); });
+
+    case ParallelismStrategy::kWithinCluster: {
+      // Ensure the worker argument is unique across clusters, because it is
+      // used for TLS indexing for example in profiler.h.
+      const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster();
+      return ctx.pools.Cluster(pkg_idx, cluster_idx)
+          .Run(0, num_tasks, [&](uint64_t task, size_t worker) {
+            func(task, base + worker);
+          });
+    }
+
+    case ParallelismStrategy::kFlat: {
+      // Check for single cluster; if not, we must compute `cluster_base` for
+      // consistent and non-overlapping worker indices.
+      hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx);
+      const size_t num_clusters = all_clusters.NumWorkers();
+      if (num_clusters == 1) {
+        return ctx.pools.Cluster(pkg_idx, cluster_idx)
+            .Run(0, num_tasks,
+                 [&](uint64_t task, size_t worker) { func(task, worker); });
+      }
+
+      return ctx.pools.AllClusters(pkg_idx).Run(
+          0, num_tasks, [&](uint64_t task, size_t cluster_idx) {
+            const size_t worker =
+                cluster_idx * ctx.pools.MaxWorkersPerCluster();
+            func(task, worker);
+          });
+    }
+
+    case ParallelismStrategy::kHierarchical:
+      return HierarchicalParallelFor(num_tasks, ctx.pools, func);
+  }
+}
+
 }  // namespace gcpp
 
 #endif  // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_CONTEXT_H_
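The comment on `kAcrossClusters` spells out the intended nesting: the outer call hands each invocation a `cluster_idx`, which is then forwarded as the `cluster_idx` argument of an inner `kWithinCluster` call so that worker indices stay unique and the DASSERT in `ParallelFor` is satisfied. A hedged sketch of that two-level pattern, assuming a constructed `ThreadingContext ctx`; `ProcessBlocks` and `ProcessItem` are illustrative names, not repository functions.

// Assumed, illustrative work function (not in the repository).
void ProcessItem(uint64_t block, uint64_t item, size_t worker);

// Illustrative only. Outer loop: one task per block, one worker per cluster.
void ProcessBlocks(ThreadingContext& ctx, size_t num_blocks,
                   size_t items_per_block) {
  ParallelFor(ParallelismStrategy::kAcrossClusters, num_blocks, ctx,
              /*cluster_idx=*/0, [&](uint64_t block, size_t cluster_idx) {
                // Inner loop: all cores of the cluster we were assigned.
                ParallelFor(ParallelismStrategy::kWithinCluster,
                            items_per_block, ctx, cluster_idx,
                            [&](uint64_t item, size_t worker) {
                              ProcessItem(block, item, worker);
                            });
              });
}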