1.03x speedup: fused FFN

matmul-inl: support CView=StridedView or RowPtrs; rename to C_MC_NC
matmul.cc: Allow 1 more rep for MC/NC (when there are two B matrices), enabling half-sized tiles, which helps performance.
PiperOrigin-RevId: 807291701
Jan Wassenberg 2025-09-15 10:25:59 -07:00 committed by Copybara-Service
parent 59db30e209
commit f3bc1c17da
11 changed files with 488 additions and 247 deletions
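
In short, the FFN change amounts to the following (a sketch using names from this diff; the full version is in the gemma-inl.h hunk below):

// Before: two MatMuls materialize C1 and C2, then a separate activation pass.
CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w1, nullptr, env,
           activations.C1);
CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w2, nullptr, env,
           activations.C2);
ActivationBatched(layer_config.activation, activations.C1, &activations.C2,
                  env.ctx);

// After: one TwoMatMul computes both products and calls the gated-GELU
// callback on each just-computed tile; the second product is only handed to
// the callback as a per-worker tile and never materialized as activations.C2.
const auto fused = [&](RowPtrsBF C1, IndexRange rows, IndexRange cols,
                       StridedViewBF C2, size_t worker) {
  Activation(layer_config.activation, C1, rows, cols, C2, env.ctx.profiler,
             worker);  // C1 = Gelu(C1) * C2 on this tile
};
MMOptions options;
options.SetFunc(fused);
CallTwoMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w1,
              layer.gating_einsum_w2, env, activations.C1, options);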

View File

@ -37,7 +37,6 @@ class GemmaBatchBench : public ::testing::Test {
protected:
std::vector<std::string> BatchGemmaReply(
const std::vector<std::string>& inputs) {
s_env->SetMaxGeneratedTokens(24);
s_env->MutableConfig().temperature = 0.0f; // deterministic
s_env->MutableConfig().verbosity = 2;
std::vector<std::string> replies;
@ -92,15 +91,18 @@ TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
inputs.push_back(questions[qpos++]);
if (qpos == questions.size()) qpos = 0;
}
s_env->SetMaxGeneratedTokens(24);
std::vector<std::string> responses = BatchGemmaReply(inputs);
for (size_t i = 0; i < HWY_MIN(hwy::Unpredictable1() * 3, responses.size());
++i) {
fprintf(stderr, "Batch answer %zu '%s'\n\n", i, responses[i].c_str());
}
PROFILER_PRINT_RESULTS();
// Run again: prefill will be faster due to autotuning. Fewer decode steps
// because those are already fast.
s_env->SetMaxGeneratedTokens(3);
s_env->SetMaxGeneratedTokens(2);
responses = BatchGemmaReply(inputs);
PROFILER_PRINT_RESULTS();

View File

@ -36,6 +36,10 @@ HWY_INLINE_VAR constexpr int kAttentionUseOld = 2;
HWY_INLINE_VAR constexpr size_t kMaxQKVDim = 1024;
#ifndef GEMMA_FUSED_FFN
#define GEMMA_FUSED_FFN 1
#endif // !GEMMA_FUSED_FFN
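// Note (assumption, based on the #ifndef guard above): a build could opt out
// of the fused path with -DGEMMA_FUSED_FFN=0, which falls back to the two
// separate MatMuls plus ActivationBatched in gemma-inl.h; the default enables
// fusion.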
// Instruction-tuned models require extra 'turn structure' tokens in prompts.
enum class PromptWrapping {
GEMMA_IT,

View File

@ -43,6 +43,7 @@ HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
// For use by Vit even if !GEMMA_FUSED_FFN.
template <typename T1, typename T2>
void Activation(ActivationType activation, T1* HWY_RESTRICT c1,
const T2* HWY_RESTRICT c2, const size_t count, hwy::Profiler& p,
@ -64,7 +65,7 @@ void Activation(ActivationType activation, T1* HWY_RESTRICT c1,
});
}
// No C2 multiplier.
// No C2 multiplier - used by Vit.
template <class Mat>
void ActivationBatched(
ActivationType activation, Mat& c1, ThreadingContext& ctx,
@ -80,6 +81,34 @@ void ActivationBatched(
});
}
#if GEMMA_FUSED_FFN
// Called during `TwoMatMul`.
static inline void Activation(ActivationType activation, const RowPtrsBF C1,
const IndexRange range_r,
const IndexRange range_c, const StridedViewBF C2,
hwy::Profiler& p, const size_t worker) {
static const auto zone = p.AddZone("Gen.ActivationFused");
PROFILER_ZONE3(p, worker, zone);
const size_t cols = range_c.Num();
HWY_DASSERT(C2.Cols() == cols);
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>;
using VF = hn::Vec<DF>;
// Assumes ActivationType::Gelu. Gated: Gelu(c1) * c2.
for (size_t ir = 0; ir < range_r.Num(); ++ir) {
Decompress1AndCompressInplace(
DF(), C1.Row(range_r.begin() + ir) + range_c.begin(), cols, C2.Row(ir),
[](DF df, VF v1, VF v2)
HWY_ATTR -> VF { return hn::Mul(v2, Gelu(df, v1)); });
}
}
#else
template <class Mat1, class Mat2>
HWY_NOINLINE void ActivationBatched(
ActivationType activation, Mat1& c1, const Mat2* c2, ThreadingContext& ctx,
@ -102,6 +131,8 @@ HWY_NOINLINE void ActivationBatched(
}
}
#endif // GEMMA_FUSED_FFN
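For reference, a minimal scalar sketch of what the SIMD lambda in the fused
Activation above computes per element, assuming the tanh-approximation Gelu
used elsewhere in this codebase (the constants here are illustrative, not part
of this diff):

// Requires <cmath>. C1 is indexed by absolute row/col, C2 by range-relative
// row/col, matching the fused Activation above.
static inline float GatedGeluRef(float c1, float c2) {
  const float kSqrt2OverPi = 0.7978845608f;  // sqrt(2/pi)
  const float gelu =
      0.5f * c1 *
      (1.0f + std::tanh(kSqrt2OverPi * (c1 + 0.044715f * c1 * c1 * c1)));
  return gelu * c2;  // gated: Gelu(c1) * c2
}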
template <typename T2, class LayerWeights>
HWY_NOINLINE void ResidualConnection(const MatPtrT<T2>& other,
MatPtrT<float>& HWY_RESTRICT x,
@ -126,28 +157,32 @@ static inline void FFWNoVit(const LayerWeightsPtrs& layer,
env.ctx.profiler.AddZone("Gen.FFW", hwy::ProfilerFlags::kInclusive);
PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone);
const LayerConfig& layer_config = layer.layer_config;
const size_t ffh_hidden_dim = layer_config.ff_hidden_dim;
const bool add_bias = layer_config.ff_biases;
const float* bias1 =
add_bias ? layer.ffw_gating_biases.PackedScale1() : nullptr;
const float* bias2 = add_bias ? bias1 + ffh_hidden_dim : nullptr;
const float* output_bias =
add_bias ? layer.ffw_output_biases.PackedScale1() : nullptr;
HWY_DASSERT(!layer_config.ff_biases); // Only used in Vit.
#if GEMMA_FUSED_FFN
const auto fused = [&](RowPtrsBF C1, IndexRange range_r, IndexRange range_c,
StridedViewBF C2, size_t worker) {
Activation(layer_config.activation, C1, range_r, range_c, C2,
env.ctx.profiler, worker);
};
MMOptions options;
options.SetFunc(fused);
CallTwoMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w1,
layer.gating_einsum_w2, env, activations.C1, options);
#else
// Compute the hidden layer activations.
CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w1, bias1, env,
CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w1, nullptr, env,
activations.C1);
CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w2, bias2, env,
CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w2, nullptr, env,
activations.C2);
// Activation (Gelu) and maybe multiply by gate. Store activations in act.
ActivationBatched(layer_config.activation, activations.C1, &activations.C2,
env.ctx);
#endif
// Hidden layer -> output layer.
CallMatMul(activations.C1, layer.linear_w, output_bias, env,
activations.ffw_out);
CallMatMul(activations.C1, layer.linear_w, nullptr, env, activations.ffw_out);
}
// NOLINTNEXTLINE(google-readability-namespace-comments)

View File

@ -155,14 +155,14 @@ class MMStoreHorizontalSumsIntoC {
template <class D4, class V4 = hn::Vec<D4>, class Tag, class CView>
HWY_INLINE void Store(D4 d4, V4 sum0, V4 sum1, V4 sum2, V4 sum3,
const float scale, const float* HWY_RESTRICT add,
const size_t imc, Tag tag, CView C_rows) const {
const size_t imc, Tag tag, CView C_MC_NR) const {
const V4 vscale = hn::Set(d4, scale);
HWY_ALIGN static constexpr float kZero[4] = {};
const V4 vadd = hn::Load(d4, add ? add : kZero);
MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, tag, imc, C_rows);
MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, tag, imc, C_rows);
MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, tag, imc, C_rows);
MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, tag, imc, C_rows);
MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, tag, imc, C_MC_NR);
MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, tag, imc, C_MC_NR);
MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, tag, imc, C_MC_NR);
MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, tag, imc, C_MC_NR);
}
private:
@ -202,10 +202,10 @@ class MMStoreHorizontalSumsIntoC {
class Tag, class CView>
static HWY_INLINE void MaybeScaleAndStore(DF4 df4, VF4 sum, VF4 vscale,
VF4 vadd, Tag, const size_t imc,
CView C_view) {
CView C_MC_NR) {
if constexpr (kRow < kRowsAC) {
using TC = hwy::RemoveCvRef<decltype(C_view.Row(0)[0])>;
TC* HWY_RESTRICT pos = C_view.Row(imc + kRow);
using TC = hwy::RemoveCvRef<decltype(C_MC_NR.Row(0)[0])>;
TC* HWY_RESTRICT pos = C_MC_NR.Row(imc + kRow);
const hn::Rebind<TC, DF4> dc4;
if constexpr (hwy::IsSame<Tag, MMAddC>()) {
vadd = F32FromTC(dc4, hn::Load(dc4, pos)); // load prior value
@ -268,9 +268,9 @@ class MMDecompress {
} else {
// Always decompress. To reduce code size/compile time, we no longer
// support a separate F32 kernel; most A are already BF16. We also only
// have a single MMStorage.
// have a single MMEntireA.
HWY_ASSERT(options.cluster_idx == 0);
const StridedViewBF A_view = env.storage.A(A.Extents());
const StridedViewBF A_view = env.A_BF.A(A.Extents());
AutotuneDecompressA(A, A_view, autotune, env, options);
return A_view;
}
@ -387,111 +387,52 @@ class MMDecompress {
// Stateless, wraps member functions. Contains the innermost 2-4 loops.
class MMKernel {
// Compute size of per-worker storage for `kNR` row ranges of B. Stack
// allocation avoids passing a worker index.
static constexpr size_t B_stride_max =
kMaxKC + 2 * CacheInfo::MaxLineBytes() / sizeof(BF16);
public:
// A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output.
// Calls `LoopKC` for each of `mc` rows of A in steps of `mr`. `A_view` is
// `mc x kc` and `B_view` is `(kNR x kc)`. All views, including `add`, start
// at row/col 0. `CView` is either `RowPtrs<TC>` or `StridedView<TC>`.
// Called by B3A2C0 and by callers that hoist `A_view`.
template <class Tag, class CView>
static HWY_INLINE void A2C0(const StridedViewBF A_view,
const StridedViewBF B_view, size_t mr,
const IndexRange& range_mc, size_t kc,
const float scale, const float* HWY_RESTRICT add,
Tag tag, CView C_view) {
HWY_DASSERT(1 <= mr && mr <= kMaxMR);
const size_t mc = range_mc.Num();
size_t imc = 0;
// M == 1, or x86 with 8 SIMD registers:
if (HWY_UNLIKELY(mr == 1)) {
for (; imc < mc; ++imc) {
LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_view);
}
return;
}
// AVX2 (16 registers)
if (HWY_UNLIKELY(mr == 2)) {
if (HWY_LIKELY(mc >= 2)) {
for (; imc <= mc - 2; imc += 2) {
LoopKC<2>(A_view, B_view, imc, kc, scale, add, tag, C_view);
}
}
if (HWY_UNLIKELY(imc != mc)) {
LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_view);
}
return;
}
HWY_DASSERT(mr == 4);
if (HWY_LIKELY(mc >= 4)) {
for (; imc <= mc - 4; imc += 4) {
LoopKC<4>(A_view, B_view, imc, kc, scale, add, tag, C_view);
}
}
const size_t remainder_mc = mc - imc;
HWY_DASSERT(remainder_mc < 4);
if (HWY_UNLIKELY(remainder_mc & 2)) {
LoopKC<2>(A_view, B_view, imc, kc, scale, add, tag, C_view);
imc += 2;
}
if (HWY_UNLIKELY(remainder_mc & 1)) {
LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_view);
imc += 1;
}
HWY_DASSERT(imc == mc);
}
static constexpr size_t B_storage_max = kNR * B_stride_max;
// Loop over NC/MC/KC, called from the outer loops. The MOMMS B3A2C0 reads
// `mc x kc` of A, `nc x kc` of B, and updates `mc x nc` of C. Called by
// `ForeachKC` and when there is only a single KC task.
template <typename TB, typename TC, typename Tag>
// `mc x kc` of A, `nc x kc` of B, and updates the `mc x nc` `C_MC_NC`.
// `CView` is either `RowPtrs<TC>` or `StridedView<TC>`.
template <typename TB, typename Tag, class CView>
static void B3A2C0(const StridedViewBF A, const MatPtrT<TB>& B,
const IndexRange& range_mc, const IndexRange& range_kc,
const IndexRange& range_nc, const MMArgs& args,
Tag out_tag, RowPtrs<TC> C) {
HWY_ALIGN BF16 B_storage[B_storage_max];
Tag out_tag, CView C_MC_NC) {
const size_t kc = range_kc.Num();
const StridedViewBF A_view = A.View(range_mc.begin(), range_kc.begin(), kc);
// Upper bound on per-worker storage for `kNR` row ranges of B. Stack
// allocation avoids passing a worker index.
constexpr size_t B_stride_max =
kMaxKC + 2 * CacheInfo::MaxLineBytes() / sizeof(BF16);
HWY_ALIGN BF16 B_storage[kNR * B_stride_max];
const size_t B_stride =
Stride(MatPadding::kOdd, kc, sizeof(BF16), args.line_bytes);
const StridedViewBF B_storage_view(B_storage, kc, B_stride);
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
row_b += kNR) {
const float scale = args.scale_A * B.Scale();
for (size_t inc = 0; inc < range_nc.Num(); inc += kNR) {
// Absolute row index for `add` and `B`, which are global, unlike the
// range-relative `C_MC_NC`.
const size_t row_b = range_nc.begin() + inc;
const StridedViewBF B_view =
MMDecompress::DecompressB(B, row_b, range_kc, B_storage_view);
const RowPtrs<TC> C_view = C.View(range_mc.begin(), row_b);
const CView C_MC_NR = C_MC_NC.View(0, inc, kNR);
const float* HWY_RESTRICT add = args.add ? args.add + row_b : nullptr;
A2C0(A_view, B_view, args.mr, range_mc, kc, args.scale, add, out_tag,
C_view);
A2C0(A_view, B_view, args.mr, range_mc, kc, scale, add, out_tag, C_MC_NR);
}
}
template <typename TB, typename TC>
template <typename TB, class CView>
static void ForeachKC(const StridedViewBF A, const MatPtrT<TB>& B,
const IndexRange& range_mc,
const IndexRangePartition& ranges_kc,
const IndexRange& range_nc, const MMArgs& args,
RowPtrs<TC> C) {
CView C_MC_NC) {
// Peel off the first iteration of the kc loop: avoid zero-initializing `C`
// by writing directly into it, and later accumulating into it.
ranges_kc.VisitFirst([&](const IndexRange& range_kc) {
B3A2C0(A, B, range_mc, range_kc, range_nc, args, MMSetC(), C);
B3A2C0(A, B, range_mc, range_kc, range_nc, args, MMSetC(), C_MC_NC);
});
ranges_kc.VisitRemaining([&](const IndexRange& range_kc) {
B3A2C0(A, B, range_mc, range_kc, range_nc, args, MMAddC(), C);
B3A2C0(A, B, range_mc, range_kc, range_nc, args, MMAddC(), C_MC_NC);
});
}
@ -585,15 +526,15 @@ class MMKernel {
// Innermost loop over `kc` columns (typically 1024-4096, not necessarily a
// multiple of `NBF`) in steps of one vector, for `kRowsAC` rows of `A_view`
// from range_mc-relative `imc` and `B_view` from row 0 (both at column 0).
// Updates a `kRowsAC x kNR` tile in `C_view` starting at row `imc`, column 0.
// `A` and `B` are always BF16, `C` can be F32 or BF16. `add` is also
// Updates a `kRowsAC x kNR` tile in `C_MC_NR` starting at row `imc`, column
// 0. `A` and `B` are always BF16, `C` can be F32 or BF16. `add` is also
// relative to the C column.
template <size_t kRowsAC, /*deduced:*/ class Tag, class CView>
static HWY_INLINE void LoopKC(const StridedViewBF A_view,
const StridedViewBF B_view, size_t imc,
size_t kc, const float scale,
const float* HWY_RESTRICT add, Tag tag,
CView C_view) {
CView C_MC_NR) {
const hn::ScalableTag<BF16> dbf;
using VBF = hn::Vec<decltype(dbf)>;
HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf);
@ -777,7 +718,62 @@ class MMKernel {
hn::Vec<decltype(d4)> sum0, sum1, sum2, sum3;
horz.Reduce4x4(df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22,
C23, C30, C31, C32, C33, sum0, sum1, sum2, sum3);
horz.Store(d4, sum0, sum1, sum2, sum3, scale, add, imc, tag, C_view);
horz.Store(d4, sum0, sum1, sum2, sum3, scale, add, imc, tag, C_MC_NR);
}
// A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output.
// Calls `LoopKC` for each of `mc` rows of A in steps of `mr`. `A_view` is
// `mc x kc` and `B_view` is `(kNR x kc)`. All views, including `add`, start
// at row/col 0.
template <class Tag, class CView>
static HWY_INLINE void A2C0(const StridedViewBF A_view,
const StridedViewBF B_view, size_t mr,
const IndexRange& range_mc, size_t kc,
const float scale, const float* HWY_RESTRICT add,
Tag tag, CView C_MC_NR) {
HWY_DASSERT(1 <= mr && mr <= kMaxMR);
const size_t mc = range_mc.Num();
size_t imc = 0;
// M == 1, or x86 with 8 SIMD registers:
if (HWY_UNLIKELY(mr == 1)) {
for (; imc < mc; ++imc) {
LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_MC_NR);
}
return;
}
// AVX2 (16 registers)
if (HWY_UNLIKELY(mr == 2)) {
if (HWY_LIKELY(mc >= 2)) {
for (; imc <= mc - 2; imc += 2) {
LoopKC<2>(A_view, B_view, imc, kc, scale, add, tag, C_MC_NR);
}
}
if (HWY_UNLIKELY(imc != mc)) {
LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_MC_NR);
}
return;
}
HWY_DASSERT(mr == 4);
if (HWY_LIKELY(mc >= 4)) {
for (; imc <= mc - 4; imc += 4) {
LoopKC<4>(A_view, B_view, imc, kc, scale, add, tag, C_MC_NR);
}
}
const size_t remainder_mc = mc - imc;
HWY_DASSERT(remainder_mc < 4);
if (HWY_UNLIKELY(remainder_mc & 2)) {
LoopKC<2>(A_view, B_view, imc, kc, scale, add, tag, C_MC_NR);
imc += 2;
}
if (HWY_UNLIKELY(remainder_mc & 1)) {
LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_MC_NR);
imc += 1;
}
HWY_DASSERT(imc == mc);
}
};
@ -813,10 +809,10 @@ class MMImpl {
}
public:
static MMPerKey& FindOrAddPerKey(size_t M, size_t K, size_t N,
static MMPerKey& FindOrAddPerKey(size_t M, size_t K, size_t N, size_t num_B,
size_t vector_bytes,
MatMulEnv::PerCluster& per_cluster) {
const MMKeys::Key key = MMKeys::KeyFromDims(M, K, N);
const MMKeys::Key key = MMKeys::KeyFromDims(M, K, N, num_B);
intptr_t index = IndexOfKey(key, per_cluster.keys);
// First time we see this shape/key.
if (HWY_UNLIKELY(index < 0)) {
@ -831,17 +827,19 @@ class MMImpl {
}
static void NotifyAutotuneResult(MatMulEnv& env, size_t M, size_t K, size_t N,
double t0, MMAutoTune<MMConfig>& tuner,
size_t num_B, double t0,
MMAutoTune<MMConfig>& tuner,
const MMConfig& cfg) {
const uint64_t t1 =
env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
const double min_elapsed = static_cast<double>(tuner.NotifyTicks(t1 - t0)) /
hwy::platform::InvariantTicksPerSecond();
const double flops = 2 * M * K * N / min_elapsed; // * 2 for FMA
const double flops = 2 * M * K * N * num_B / min_elapsed; // * 2 for FMA
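// num_B == 2 counts both A*B1 and A*B2, so the reported GFLOP/s remains
// comparable with running two separate MatMuls.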
if (HWY_UNLIKELY(env.print_measurement && tuner.ShouldPrint())) {
fprintf(stderr, "%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu\n", flops * 1E-9,
min_elapsed * 1E3, cfg.MR(), cfg.MC(), cfg.KC(), cfg.NC(),
StringFromOrder(cfg.Order()), cfg.InnerTasks());
fprintf(stderr, "%zu,%zu,%zu,%zu,%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu\n",
M, K, N, num_B, flops * 1E-9, min_elapsed * 1E3, cfg.MR(),
cfg.MC(), cfg.KC(), cfg.NC(), StringFromOrder(cfg.Order()),
cfg.InnerTasks());
}
if (HWY_UNLIKELY(env.print_best && tuner.Best())) {
const auto ratio = [&tuner](uint64_t ticks) -> double {
@ -849,12 +847,13 @@ class MMImpl {
static_cast<double>(tuner.BestTicks());
};
const MMConfig& best = *tuner.Best();
fprintf(stderr,
"\n%zu,%zu,%zu,%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu,%.2f,%.2f\n",
M, K, N, flops * 1E-9, min_elapsed * 1E3, best.MR(), best.MC(),
best.KC(), best.NC(), StringFromOrder(best.Order()),
best.InnerTasks(), ratio(tuner.WorstMinTicks()),
ratio(tuner.FirstConfigTicks()));
fprintf(
stderr,
"\n%zu,%zu,%zu,%zu,%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu,%.2f,%.2f\n",
M, K, N, num_B, flops * 1E-9, min_elapsed * 1E3, best.MR(), best.MC(),
best.KC(), best.NC(), StringFromOrder(best.Order()),
best.InnerTasks(), ratio(tuner.WorstMinTicks()),
ratio(tuner.FirstConfigTicks()));
}
}
@ -874,10 +873,11 @@ class MMImpl {
class MMLoops {
public:
// Called from `MatMul` from two places: either with the next autotune config,
// or with the best config.
// or with the best config. `B2` is null unless called from `TwoMatMul`.
template <typename TB, typename TC>
static HWY_NOINLINE void Dispatch(const StridedViewBF A, const MatPtrT<TB>& B,
RowPtrs<TC> C, const MMArgs& args) {
const MatPtrT<TB>* B2, RowPtrs<TC> C,
const MMArgs& args) {
static const auto zone = args.env.ctx.profiler.AddZone("MM.Dispatch");
PROFILER_ZONE3(args.env.ctx.profiler,
args.env.ctx.Worker(args.options.cluster_idx), zone);
@ -885,7 +885,7 @@ class MMLoops {
DispatchParallelism(
args.options.parallelism, [&](const auto& parallel) HWY_ATTR {
DispatchOrder(args.order, [&](const auto& order) HWY_ATTR {
Loop(order, parallel, A, B, C, args);
Loop(order, parallel, A, B, B2, C, args);
});
});
}
@ -901,18 +901,14 @@ class MMLoops {
template <typename TB, typename TC, class Parallel>
static HWY_INLINE void Loop(MMOrderNT, Parallel parallel,
const StridedViewBF A, const MatPtrT<TB>& B,
RowPtrs<TC> C, const MMArgs& args) {
const MatPtrT<TB>* B2, RowPtrs<TC> C,
const MMArgs& args) {
static const auto zone = args.env.ctx.profiler.AddZone("MM.NT");
HWY_DASSERT(args.ranges_mc.NumTasks() == 1);
HWY_DASSERT(args.ranges_kc.NumTasks() == 1);
const IndexRange& range_M = args.ranges_mc.Range(0);
const IndexRange& range_K = args.ranges_kc.Range(0);
const size_t K = range_K.Num();
const StridedViewBF A_view = A.View(range_M.begin(), 0, K);
const size_t B_stride =
Stride(MatPadding::kOdd, K, sizeof(BF16), args.line_bytes);
const IndexRange& range_mc = args.ranges_mc.Range(0);
const IndexRange& range_kc = args.ranges_kc.Range(0);
// Similar to `B3A2C0`, but here we hoisted `A_view`.
parallel.ForN(
args.env.ctx, args.range_n, MultipleN(sizeof(TC), args.line_bytes),
args.inner_tasks, args.options.cluster_idx,
@ -920,26 +916,19 @@ class MMLoops {
MMZone mm_zone;
mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune);
HWY_ALIGN BF16 B_storage[MMKernel::B_storage_max]; // TLS
const StridedViewBF B_storage_view(B_storage, K, B_stride);
MMKernel::B3A2C0(A, B, range_mc, range_kc, range_nc, args, MMSetC(),
C.View(0, range_nc.begin(), range_nc.Num()));
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
row_b += kNR) {
const StridedViewBF B_view =
MMDecompress::DecompressB(B, row_b, range_K, B_storage_view);
const RowPtrs<TC> C_view = C.View(range_M.begin(), row_b);
const float* HWY_RESTRICT add =
args.add ? args.add + row_b : nullptr;
const StridedViewBF C2 = args.env.C_tiles.C(
Extents2D(range_mc.Num(), range_nc.Num()), worker);
MMKernel::A2C0(A_view, B_view, args.mr, range_M, K, args.scale, add,
MMSetC(), C_view);
if (B2 != nullptr) {
MMKernel::B3A2C0(A, *B2, range_mc, range_kc, range_nc, args,
MMSetC(), C2);
}
if constexpr (IsBF16<TC>()) {
if (args.options.fused) {
StridedViewBF C2(nullptr, 0, 0);
args.options.fused(C, range_M, range_nc, C2, worker);
}
args.options.MaybeCallFunc(C, range_mc, range_nc, C2, worker);
}
});
}
@ -948,7 +937,8 @@ class MMLoops {
template <typename TB, typename TC, class Parallel>
static HWY_INLINE void Loop(MMOrderNT_K, Parallel parallel,
const StridedViewBF A, const MatPtrT<TB>& B,
RowPtrs<TC> C, const MMArgs& args) {
const MatPtrT<TB>* B2, RowPtrs<TC> C,
const MMArgs& args) {
static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_K");
HWY_DASSERT(args.ranges_mc.NumTasks() == 1);
const IndexRange& range_mc = args.ranges_mc.Range(0);
@ -959,14 +949,21 @@ class MMLoops {
[&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
MMZone mm_zone;
mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune);
MMKernel::ForeachKC(A, B, range_mc, args.ranges_kc,
range_nc, args, C);
MMKernel::ForeachKC(
A, B, range_mc, args.ranges_kc, range_nc, args,
C.View(0, range_nc.begin(), range_nc.Num()));
const StridedViewBF C2 = args.env.C_tiles.C(
Extents2D(range_mc.Num(), range_nc.Num()), worker);
if (B2 != nullptr) {
MMKernel::ForeachKC(A, *B2, range_mc, args.ranges_kc,
range_nc, args, C2);
}
if constexpr (IsBF16<TC>()) {
if (args.options.fused) {
StridedViewBF C2(nullptr, 0, 0);
args.options.fused(C, range_mc, range_nc, C2, worker);
}
args.options.MaybeCallFunc(C, range_mc, range_nc, C2,
worker);
}
});
}
@ -976,10 +973,11 @@ class MMLoops {
template <typename TB, typename TC, class Parallel>
static HWY_INLINE void Loop(MMOrderNT_MT, Parallel parallel,
const StridedViewBF A, const MatPtrT<TB>& B,
RowPtrs<TC> C, const MMArgs& args) {
const MatPtrT<TB>* B2, RowPtrs<TC> C,
const MMArgs& args) {
static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_MT");
HWY_DASSERT(args.ranges_kc.NumTasks() == 1);
const IndexRange& range_K = args.ranges_kc.Range(0);
const IndexRange& range_kc = args.ranges_kc.Range(0);
parallel.ForRangesMC_NC(
args.env.ctx, args.ranges_mc, args.ranges_nc, args.options.cluster_idx,
@ -987,14 +985,19 @@ class MMLoops {
size_t worker) HWY_ATTR {
MMZone mm_zone;
mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune);
MMKernel::B3A2C0(A, B, range_mc, range_K, range_nc, args, MMSetC(),
C);
MMKernel::B3A2C0(
A, B, range_mc, range_kc, range_nc, args, MMSetC(),
C.View(range_mc.begin(), range_nc.begin(), range_nc.Num()));
const StridedViewBF C2 = args.env.C_tiles.C(
Extents2D(range_mc.Num(), range_nc.Num()), worker);
if (B2 != nullptr) {
MMKernel::B3A2C0(A, *B2, range_mc, range_kc, range_nc, args,
MMSetC(), C2);
}
if constexpr (IsBF16<TC>()) {
if (args.options.fused) {
StridedViewBF C2(nullptr, 0, 0);
args.options.fused(C, range_mc, range_nc, C2, worker);
}
args.options.MaybeCallFunc(C, range_mc, range_nc, C2, worker);
}
});
}
@ -1004,7 +1007,8 @@ class MMLoops {
template <typename TB, typename TC, class Parallel>
static HWY_INLINE void Loop(MMOrderNT_MT_K, Parallel parallel,
const StridedViewBF A, const MatPtrT<TB>& B,
RowPtrs<TC> C, const MMArgs& args) {
const MatPtrT<TB>* B2, RowPtrs<TC> C,
const MMArgs& args) {
static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_MT_K");
parallel.ForRangesMC_NC(
@ -1013,14 +1017,20 @@ class MMLoops {
size_t worker) HWY_ATTR {
MMZone mm_zone;
mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune);
MMKernel::ForeachKC(A, B, range_mc, args.ranges_kc, range_nc, args,
C);
MMKernel::ForeachKC(
A, B, range_mc, args.ranges_kc, range_nc, args,
C.View(range_mc.begin(), range_nc.begin(), range_nc.Num()));
const StridedViewBF C2 = args.env.C_tiles.C(
Extents2D(range_mc.Num(), range_nc.Num()), worker);
if (B2 != nullptr) {
MMKernel::ForeachKC(A, *B2, range_mc, args.ranges_kc, range_nc,
args, C2);
}
if constexpr (IsBF16<TC>()) {
if (args.options.fused) {
StridedViewBF C2(nullptr, 0, 0);
args.options.fused(C, range_mc, range_nc, C2, worker);
}
args.options.MaybeCallFunc(C, range_mc, range_nc, C2, worker);
}
});
}
@ -1060,20 +1070,23 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
const size_t M = A.Rows();
const size_t K = A.Cols();
const size_t N = B.Rows();
const size_t num_B = 1;
const CacheInfo& cache = env.ctx.cache_info;
MMPerKey& per_key = MMImpl::FindOrAddPerKey(M, K, N, cache.VectorBytes(),
env.per_cluster[cluster_idx]);
MMPerKey& per_key = MMImpl::FindOrAddPerKey(
M, K, N, num_B, cache.VectorBytes(), env.per_cluster[cluster_idx]);
// (Also auto-tunes, hence outside the timed section to prevent interference.)
const StridedViewBF A_view =
MMDecompress::MaybeDecompressA(A, per_key.autotune_par_a, env, options);
MatPtrT<TB>* B2 = nullptr; // required for type matching
MMAutoTune<MMConfig>& tuner = per_key.autotune;
if (HWY_LIKELY(tuner.Best())) {
const MMArgs args(env, M, K, N, static_cast<double>(A.Scale()) * B.Scale(),
add, options, tuner, *tuner.Best());
MMLoops::Dispatch(A_view, B, C_rows, args);
const MMArgs args(env, M, K, N, A.Scale(), add, options, tuner,
*tuner.Best());
MMLoops::Dispatch(A_view, B, B2, C_rows, args);
return &per_key;
}
@ -1082,20 +1095,83 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
// Ensure matrix dimensions match each other (off the hot path).
HWY_ASSERT(K == B.Cols());
HWY_ASSERT(M <= kMaxBatchSize);
HWY_ASSERT(K <= MMStorage::kMaxK);
HWY_ASSERT(K <= MMEntireA::kMaxK);
HWY_ASSERT(N % kNR == 0);
MMImpl::EnsureAligned(A, cache.VectorBytes());
tuner.SetCandidates(
MMCandidates(cache, M, K, N, sizeof(TC), env.print_config));
MMCandidates(cache, M, K, N, num_B, sizeof(TC), env.print_config));
}
const MMConfig& cfg = tuner.NextConfig();
const MMArgs args(env, M, K, N, static_cast<double>(A.Scale()) * B.Scale(),
add, options, tuner, cfg);
const MMArgs args(env, M, K, N, A.Scale(), add, options, tuner, cfg);
const uint64_t t0 = hwy::timer::Start();
MMLoops::Dispatch(A_view, B, C_rows, args);
MMImpl::NotifyAutotuneResult(env, M, K, N, t0, tuner, cfg);
MMLoops::Dispatch(A_view, B, B2, C_rows, args);
MMImpl::NotifyAutotuneResult(env, M, K, N, num_B, t0, tuner, cfg);
return &per_key;
}
// Performs A*B1 and A*B2 in parallel. This is useful for gated FFNs.
// Differences vs MatMul: the second result matrix is not materialized; it is
// only passed to the `options.func` callback. There is no `add` argument
// because it is not required for this use case. There is no default `options`
// argument because `options.func` must be set by the caller.
template <typename TB>
HWY_NOINLINE MMPerKey* TwoMatMul(const MatPtrT<BF16>& A, const MatPtrT<TB>& B1,
const MatPtrT<TB>& B2, MatMulEnv& env,
MatPtrT<BF16>& C, MMOptions options) {
static const auto zone = env.ctx.profiler.AddZone("MM.TwoMatMul");
const size_t cluster_idx = options.cluster_idx;
HWY_DASSERT(cluster_idx < env.row_ptrs.size());
PROFILER_ZONE3(env.ctx.profiler, env.ctx.Worker(cluster_idx), zone);
HWY_DASSERT(options.func != nullptr); // no other way to get access to C2.
RowPtrs<BF16> C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[cluster_idx]);
const size_t M = A.Rows();
const size_t K = A.Cols();
const size_t N = B1.Rows();
const size_t num_B = 2;
const CacheInfo& cache = env.ctx.cache_info;
MMPerKey& per_key = MMImpl::FindOrAddPerKey(
M, K, N, num_B, cache.VectorBytes(), env.per_cluster[cluster_idx]);
// (Also auto-tunes, hence outside the timed section to prevent interference.)
const StridedViewBF A_view(A, 0, 0, A.Cols());
MMAutoTune<MMConfig>& tuner = per_key.autotune;
if (HWY_LIKELY(tuner.Best())) {
// Only A scale - B1/B2 may differ, and are passed separately.
const MMArgs args(env, M, K, N, A.Scale(),
/*add=*/nullptr, options, tuner, *tuner.Best());
MMLoops::Dispatch(A_view, B1, &B2, C_rows, args);
return &per_key;
}
// Autotuning, first call: enumerate all feasible configs.
if (HWY_UNLIKELY(!tuner.HasCandidates())) {
// Ensure matrix dimensions match each other (off the hot path).
HWY_ASSERT(K == B1.Cols());
HWY_ASSERT(K == B2.Cols());
HWY_ASSERT(M <= kMaxBatchSize);
HWY_ASSERT(K <= MMEntireA::kMaxK);
HWY_ASSERT(N % kNR == 0);
MMImpl::EnsureAligned(A, cache.VectorBytes());
tuner.SetCandidates(
MMCandidates(cache, M, K, N, num_B, sizeof(BF16), env.print_config));
}
const MMConfig& cfg = tuner.NextConfig();
// Only A scale - B1/B2 may differ, and are passed separately.
const MMArgs args(env, M, K, N, A.Scale(), /*add=*/nullptr, options, tuner,
cfg);
const uint64_t t0 = hwy::timer::Start();
MMLoops::Dispatch(A_view, B1, &B2, C_rows, args);
MMImpl::NotifyAutotuneResult(env, M, K, N, num_B, t0, tuner, cfg);
return &per_key;
}

View File

@ -63,11 +63,12 @@ size_t PrevDivisor(const size_t begin, const size_t end, const size_t dim,
class GenerateCandidates {
public:
GenerateCandidates(const CacheInfo& cache, size_t M, size_t K, size_t N,
size_t sizeof_TC, bool print_config)
size_t num_B, size_t sizeof_TC, bool print_config)
: cache_(cache),
M_(M),
K_(K),
N_(N),
num_B_(num_B),
sizeof_TC_(sizeof_TC),
// These influence kc/nc, but are also stored in `MMConfig` for
// `RangesOf*`. Must be a vector multiple. The previous/next cache line
@ -150,7 +151,7 @@ class GenerateCandidates {
}
}
// The number of A and B columns to read between updating `partial`.
// The number of A and B columns to read between updating `C`.
SizeVec KC(size_t mr, MMOrder order) const {
// `LoopKC` handles up to `mr` rows of A.
const size_t rows_a = HWY_MIN(M_, mr);
@ -164,7 +165,7 @@ class GenerateCandidates {
// TB=NUQ due to less amortization of the table loads. Due to the low L1
// latency, the packing is still effectively fused into `LoopKC`. It may
// be better to round up and accept a few L2 accesses in exchange for
// fewer loops over K, and thus fewer writes to `partial`. Hence we do not
// fewer loops over K, and thus fewer writes to `C`. Hence we do not
// subtract the output and buf, and allow using more than the actual L1
// size. This results in an overestimate, and the loop below will propose
// the next few smaller values for the autotuner to evaluate.
@ -179,7 +180,7 @@ class GenerateCandidates {
// Avoid proposing kc > K.
if (K_ > kc_multiple_) {
// Generally it is best to use the full `kc` (fewer writes to `partial`),
// Generally it is best to use the full `kc` (fewer writes to `C`),
// but a bit less can be better if it evenly divides `K`, or enables an
// `mc` that evenly divides `M`. Try several smaller values.
@ -196,7 +197,7 @@ class GenerateCandidates {
}
if (print_config_ && all_kc.size() > 1) {
fprintf(stderr, "KC: ");
fprintf(stderr, "num_B %zu: KC: ", num_B_);
for (size_t kc : all_kc) {
fprintf(stderr, "%zu ", kc);
}
@ -214,18 +215,18 @@ class GenerateCandidates {
// Choose the largest feasible `mc_max` (A/C rows) to maximize reuse of the
// packed B. We want `mc * kc` elements of A to fit in L2, alongside
// `bytes_b` plus `mc` cache lines because resident-A updates `mc` rows of
// partial.
// `bytes_b` plus `mc` cache lines because resident-A updates `mc` C rows.
const size_t bytes_per_mc = kc * sizeof(BF16) + cache_.LineBytes();
size_t mc_max = hwy::DivCeil(cache_.L2Bytes() - bytes_b, bytes_per_mc);
mc_max = HWY_MIN(mc_max, kMaxBatchSize);
mc_max = HWY_MIN(mc_max, HWY_MIN(kMaxBatchSize, kMaxMC));
HWY_DASSERT(mc_max != 0);
mc_max = HWY_MIN(mc_max, M_);
mc_max = hwy::RoundDownTo(mc_max, mr);
SizeVec all_mc(1, mc_max);
// Larger MC is better for non-blocks, otherwise we want more small options.
const size_t reps = !IsBlock(order) ? 2 : 3;
// Larger MC is better for non-blocks, otherwise we want more small options,
// especially for two B.
const size_t reps = !IsBlock(order) ? 2 : (2 + num_B_);
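// e.g. block orders get 3 MC candidates for num_B_ == 1 (as before) and 4 for
// num_B_ == 2, i.e. one extra, roughly half-sized tile to try.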
size_t prev = mc_max;
for (size_t rep = 0; rep < reps; ++rep) {
@ -240,7 +241,7 @@ class GenerateCandidates {
}
if (print_config_ && all_mc.size() > 1) {
fprintf(stderr, "MC: ");
fprintf(stderr, "num_B %zu: MC: ", num_B_);
for (size_t mc : all_mc) {
fprintf(stderr, "%zu ", mc);
}
@ -252,14 +253,15 @@ class GenerateCandidates {
// The number of (possibly L3 resident) B rows per `NT_MT` task.
SizeVec NC(size_t mr, size_t mc, size_t kc, MMOrder order) const {
size_t nc_max = N_;
size_t nc_max = kMaxNC;
// Only if there will be reuse of B: choose the largest `nc_max` (C cols)
// such that `nc x kc` of B and `mc x nc` of `partial` or `C` fit in L3.
// Otherwise, leave it unbounded.
// such that `nc x kc` of B and `mc x nc` of `C` fit in L3. Otherwise,
// leave it unbounded.
if (M_ > mr) {
const size_t bytes_per_nc = (kc * sizeof(BF16) + mc * sizeof_TC_);
nc_max = HWY_MIN(hwy::DivCeil(cache_.L3Bytes(), bytes_per_nc), N_);
nc_max = HWY_MIN(hwy::DivCeil(cache_.L3Bytes(), bytes_per_nc), kMaxNC);
}
nc_max = HWY_MIN(nc_max, N_);
HWY_DASSERT(nc_max != 0);
nc_max = RoundDownWithFloor(nc_max, nc_multiple_);
@ -278,7 +280,7 @@ class GenerateCandidates {
if (N_ > nc_multiple_) {
// Large L3, but its behavior and characteristics varies across platforms,
// hence autotune a wider range of nc than the other dimensions.
size_t reps = 10;
size_t reps = 9 + num_B_;
// For small M, we can afford larger NC, hence allow fewer small options.
if (M_ <= 2 * mr) reps -= 1;
@ -301,7 +303,7 @@ class GenerateCandidates {
}
if (print_config_ && all_nc.size() > 1) {
fprintf(stderr, "NC: ");
fprintf(stderr, "num_B %zu: NC: ", num_B_);
for (size_t nc : all_nc) {
fprintf(stderr, "%zu ", nc);
}
@ -329,6 +331,7 @@ class GenerateCandidates {
const size_t M_;
const size_t K_;
const size_t N_;
const size_t num_B_;
const size_t sizeof_TC_;
const size_t kc_multiple_;
@ -341,12 +344,13 @@ class GenerateCandidates {
// Facade to avoid exposing `GenerateCandidates` in the header.
std::vector<MMConfig> MMCandidates(const CacheInfo& cache, size_t M, size_t K,
size_t N, size_t sizeof_TC,
size_t N, size_t num_B, size_t sizeof_TC,
bool print_config) {
return GenerateCandidates(cache, M, K, N, sizeof_TC, print_config)();
return GenerateCandidates(cache, M, K, N, num_B, sizeof_TC, print_config)();
}
MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx), storage(ctx.allocator) {
MatMulEnv::MatMulEnv(ThreadingContext& ctx)
: ctx(ctx), A_BF(ctx.allocator), C_tiles(ctx) {
const size_t num_clusters = ctx.pools.AllClusters(/*pkg_idx=*/0).NumWorkers();
per_cluster.resize(num_clusters);
for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) {

View File

@ -21,7 +21,6 @@
#include <stddef.h>
#include <stdint.h>
#include <functional>
#include <vector>
// IWYU pragma: begin_exports
@ -54,7 +53,9 @@ HWY_INLINE_VAR constexpr size_t kNR = 4;
// or less on ISAs with fewer registers, or for the last few rows of A.
HWY_INLINE_VAR constexpr size_t kMaxMR = 4;
HWY_INLINE_VAR constexpr size_t kMaxNC = 16384; // TODO: shrink?
// For `MMTilesC`.
HWY_INLINE_VAR constexpr size_t kMaxMC = 512;
HWY_INLINE_VAR constexpr size_t kMaxNC = 16384;
// Upper bound for per-worker B storage on the stack. Chosen such that one row
// of BF16 A and B fit in 32 KiB L1, but there may be up to `kMaxMR` rows of A
// and `kNR` rows of B.
@ -108,9 +109,9 @@ struct MMParallelWithinCluster {
hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
const size_t base = ctx.Worker(cluster_idx);
const IndexRangePartition worker_ranges = StaticPartition(
const IndexRangePartition ranges_n = StaticPartition(
range_n, cluster.NumWorkers() * inner_tasks, n_multiple);
ParallelizeOneRange(worker_ranges, cluster,
ParallelizeOneRange(ranges_n, cluster,
[&](const IndexRange& worker_range, size_t worker) {
func(worker_range, base + worker);
});
@ -169,20 +170,20 @@ struct MMParallelHierarchical {
if (num_clusters == 1) {
const size_t cluster_idx = 0;
hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
const IndexRangePartition worker_ranges = StaticPartition(
const IndexRangePartition ranges_n = StaticPartition(
range_n, cluster.NumWorkers() * inner_tasks, n_multiple);
return ParallelizeOneRange(
worker_ranges, cluster,
ranges_n, cluster,
[&](const IndexRange& worker_range, size_t worker) {
func(worker_range, worker);
});
}
// Assign each cluster a sub-range of `range_n` (typically hundreds).
const IndexRangePartition n_ranges =
const IndexRangePartition ranges_n =
StaticPartition(range_n, num_clusters, n_multiple);
ParallelizeOneRange(
n_ranges, all_clusters,
ranges_n, all_clusters,
[&](const IndexRange& n_range, const size_t cluster_idx) {
hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
const size_t cluster_base = ctx.Worker(cluster_idx);
@ -274,32 +275,51 @@ void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC);
// C is BF16/float.
void BindC(ThreadingContext& ctx, MatPtr& C);
// For A.
class MMStorage {
// Space for converting A=F32 to BF16 before the matmul. This is faster than
// on-the-fly when native BF16 is available: it only happens once, not per B
// tile row, and the cache footprint is smaller.
class MMEntireA {
public:
// Compile-time bounds on matrix columns to enable pre-allocating storage
// and reusing it across `MatMul` calls. Sufficient for Gemma 2 27B.
static constexpr size_t kMaxK = 36 * 1024;
MMStorage(const Allocator& allocator)
explicit MMEntireA(const Allocator& allocator)
// 288 MiB. Must be padded, see `DoDecompressA`.
: A_("A_bf", Extents2D(kMaxBatchSize, kMaxK), allocator,
MatPadding::kOdd) {}
// Returns matrix view. Converting A=F32 to BF16 up-front is faster than
// on-the-fly when native BF16 is available: it only happens once, not per B
// tile row, and the cache footprint is smaller.
StridedViewBF A(const Extents2D& extents) const {
HWY_DASSERT(extents.rows <= kMaxBatchSize);
HWY_DASSERT(extents.cols <= kMaxK);
return StridedViewBF(const_cast<BF16*>(A_.Row(0)), extents.cols,
A_.Stride());
return StridedViewBF(A_, 0, 0, extents.cols);
}
private:
MatStorageT<BF16> A_;
};
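// Size check (derived from the 288 MiB figure above, not stated in this diff):
// 288 MiB / (36 Ki cols * 2 bytes per BF16) = 4096, so kMaxBatchSize is about
// 4096 rows (approximate because of the kOdd row padding).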
// One tile of C per *worker* (required for `kNT_MT*`).
class MMTilesC {
public:
explicit MMTilesC(const ThreadingContext& ctx) {
const size_t max_workers = ctx.pools.MaxWorkers();
C_.reserve(max_workers);
for (size_t worker = 0; worker < max_workers; ++worker) {
C_.push_back(MatStorageT<BF16>("Ctile", Extents2D(kMaxBatchSize, kMaxNC),
ctx.allocator, MatPadding::kOdd));
}
}
StridedViewBF C(const Extents2D& extents, size_t worker) const {
HWY_DASSERT(extents.rows <= kMaxBatchSize);
HWY_DASSERT(worker < C_.size());
return StridedViewBF(C_[worker], 0, 0, extents.cols);
}
private:
std::vector<MatStorageT<BF16>> C_;
};
//------------------------------------------------------------------------------
// Autotuning
@ -471,7 +491,7 @@ static_assert(sizeof(MMConfig) == 32); // for faster indexing
#pragma pack(pop)
std::vector<MMConfig> MMCandidates(const CacheInfo& cache, size_t M, size_t K,
size_t N, size_t sizeof_TC,
size_t N, size_t num_B, size_t sizeof_TC,
bool print_config);
// State machine for choosing the best `TConfig`, which is `MMConfig` for the
@ -595,12 +615,14 @@ class MMKeys {
static constexpr Key kPadding = 0;
// Compresses the dimensions into a single Key for faster comparison.
static Key KeyFromDims(size_t M, size_t K, size_t N) {
static Key KeyFromDims(size_t M, size_t K, size_t N, size_t num_B) {
HWY_DASSERT(M < (Key{1} << 16)); // batch sizes are smaller
HWY_DASSERT(K < (Key{1} << 24));
HWY_DASSERT(N < (Key{1} << 24));
HWY_DASSERT(K < (Key{1} << 20));
HWY_DASSERT(N < (Key{1} << 20));
HWY_DASSERT(num_B == 1 || num_B == 2);
const Key key = static_cast<Key>(BucketM(M)) | (static_cast<Key>(K) << 16) |
(static_cast<Key>(N) << 40);
(static_cast<Key>(N) << 40) |
(static_cast<Key>(num_B) << 60);
HWY_DASSERT(key != kPadding);
return key;
}
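For reference, the key layout implied by the shifts and asserts above (a sketch; bit positions inferred from the code):

// bits [ 0,16): BucketM(M)   batch-size bucket
// bits [16,36): K            asserted < 2^20
// bits [40,60): N            asserted < 2^20
// bits [60,62): num_B        1 for MatMul, 2 for TwoMatMul
// Including num_B means the same (M, K, N) shape autotunes separately for the
// one-B and two-B cases.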
@ -643,10 +665,6 @@ class MMKeys {
// Per-MatMul-shape state.
struct MMPerKey {
// Only profile if enabled and the main autotuner finished. `autotune_par_a`
// might not be active if inputs are all BF16.
bool WantProfile() const { return PROFILER_ENABLED != 0 && autotune.Best(); }
MMAutoTune<MMConfig> autotune;
MMAutoTune<MMParA> autotune_par_a;
};
@ -666,12 +684,15 @@ struct MatMulEnv {
// Whether to print the best config immediately after autotuning finished.
bool print_best = false;
MMStorage storage;
MMEntireA A_BF;
MMTilesC C_tiles;
struct PerCluster {
MMKeys keys;
std::vector<MMPerKey> per_key;
HWY_MAYBE_UNUSED uint8_t padding[HWY_ALIGNMENT]; // prevent false sharing
// Prevents false sharing.
HWY_MAYBE_UNUSED uint8_t
padding[HWY_ALIGNMENT - sizeof(MMKeys) - sizeof(per_key)];
};
std::vector<PerCluster> per_cluster;
@ -687,31 +708,57 @@ struct MatMulEnv {
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
};
// Called with the entire C matrix, the sub-ranges of M (rows) and N (cols)
// that this thread has just filled, a view into a second tile (only for the
// upcoming `GatedMatmul`), and the worker thread index (see `ParallelFor`).
using MMFused = std::function<void(RowPtrsBF, IndexRange, IndexRange,
StridedViewBF, size_t)>;
// Called via `CallClosure`, which consumes the first (opaque) argument. User
// functions are called with the entire C matrix, the sub-ranges of M (rows)
// and N (cols) that this thread has just filled, a view into a second tile
// (only for `TwoMatMul`), and the worker thread index (see `ParallelFor`).
typedef void (*MMFunc)(const void* opaque, RowPtrsBF, IndexRange, IndexRange,
StridedViewBF, size_t);
class MMOptions {
// Same technique as in `hwy::ThreadPool` and C++26 `std::function_ref`:
// type-erasure without allocation.
template <class Closure>
static void CallClosure(const void* opaque, RowPtrsBF C1, IndexRange range_r,
IndexRange range_c, StridedViewBF C2, size_t worker) {
(*reinterpret_cast<const Closure*>(opaque))(C1, range_r, range_c, C2,
worker);
}
public:
// `closure` must remain alive until the end of (Two)MatMul.
template <class Closure>
void SetFunc(const Closure& closure) {
func = static_cast<MMFunc>(&CallClosure<Closure>);
opaque = &closure;
}
void MaybeCallFunc(RowPtrsBF C1, IndexRange range_r, IndexRange range_c,
StridedViewBF C2, size_t worker) const {
if (func != nullptr) {
func(opaque, C1, range_r, range_c, C2, worker);
}
}
MMFunc func = nullptr; // called if non-null and `TC` is BF16.
const void* opaque = nullptr;
struct MMOptions {
uint32_t cluster_idx = 0; // for `parallelism == kWithinCluster`.
ParallelismStrategy parallelism = ParallelismStrategy::kHierarchical;
MMFused fused; // called if non-null and `TC` is BF16.
};
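A caller-side sketch (illustrative; it mirrors FFWNoVit in gemma-inl.h): the
closure is captured by address via SetFunc, so it must outlive the call.

const auto fused = [&](RowPtrsBF C1, IndexRange rows, IndexRange cols,
                       StridedViewBF C2, size_t worker) { /* Gelu(C1) * C2 */ };
MMOptions options;
options.SetFunc(fused);                 // stores &fused plus a trampoline; no allocation
TwoMatMul(A, B1, B2, env, C, options);  // `fused` must still be alive here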
// Arguments to MatMul() that are independent of the A/B/C types. Reduces
// register pressure compared to individual values/references. Also used for
// passing through `DispatchOrder`.
struct MMArgs {
MMArgs(MatMulEnv& env, size_t M, size_t K, size_t N, double scale,
MMArgs(MatMulEnv& env, size_t M, size_t K, size_t N, float scale_A,
const float* HWY_RESTRICT add, MMOptions options,
const MMAutoTune<MMConfig>& autotune, const MMConfig& config)
: env(env),
line_bytes(env.ctx.cache_info.LineBytes()),
range_n(0, N),
scale(scale),
scale_A(scale_A),
add(add),
options(options),
@ -728,7 +775,8 @@ struct MMArgs {
// MatMul arguments:
const IndexRange range_n; // entire N
const double scale;
// There can be two B, so do not yet multiply together the A and B scales.
const float scale_A;
const float* HWY_RESTRICT add;
const MMOptions options;

View File

@ -53,6 +53,14 @@ namespace HWY_NAMESPACE {
// included from matmul_static_*.cc.
GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DEFINE_ONE, GEMMA_MATMUL_TB) // NOLINT
HWY_MAYBE_UNUSED void TwoMatMulStatic(const MatPtrT<BF16>& A, // NOLINT
const MatPtrT<GEMMA_MATMUL_TB>& B1,
const MatPtrT<GEMMA_MATMUL_TB>& B2,
MatMulEnv& env, MatPtrT<BF16>& C,
MMOptions options) {
TwoMatMul(A, B1, B2, env, C, options);
}
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();

View File

@ -37,13 +37,19 @@
const float* HWY_RESTRICT add, MatMulEnv& env, \
MatPtrT<TC>& C, MMOptions options);
#define GEMMA_MATMUL_FOR_B(TB) \
GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, TB) \
void TwoMatMulStatic(const MatPtrT<BF16>& A, const MatPtrT<TB>& B1, \
const MatPtrT<TB>& B2, MatMulEnv& env, \
MatPtrT<BF16>& C, MMOptions options);
// Passed to HWY_VISIT_TARGETS; declares all overloads for all targets.
#define GEMMA_MATMUL_DECL(TARGET, NAMESPACE) \
namespace NAMESPACE { \
GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, BF16) \
GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, float) \
GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, NuqStream) \
GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, SfpStream) \
GEMMA_MATMUL_FOR_B(BF16) \
GEMMA_MATMUL_FOR_B(float) \
GEMMA_MATMUL_FOR_B(NuqStream) \
GEMMA_MATMUL_FOR_B(SfpStream) \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE

View File

@ -29,6 +29,8 @@
#include <stddef.h>
#include <stdio.h>
#include <atomic>
#include "ops/matmul.h"
#include "util/basics.h"
#include "util/mat.h"
@ -246,7 +248,9 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
MatStorageT<TC> C_slow("C_slow", C_extents, env.ctx.allocator,
MatPadding::kOdd);
MatStorageT<TC> C("C", C_extents, env.ctx.allocator, MatPadding::kOdd);
MatStorageT<TC> C2("C", C_extents, env.ctx.allocator, MatPadding::kOdd);
C.AllocateAndAttachRowPtrs(env.row_ptrs);
C2.AllocateAndAttachRowPtrs(env.row_ptrs);
MatStorageT<float> add_storage =
add ? GenerateMat<float>(Extents2D(1, cols_bc), env.ctx.allocator,
@ -262,7 +266,48 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
for (size_t rep = 0; rep < 16; ++rep) {
MMPerKey* per_key = MatMulStatic(A, BT, add_row, env, C, options);
AssertClose(A, BT, C_slow, C, env, line);
if (per_key->autotune.Best()) break;
// Check before TwoMatMulStatic(), which can invalidate per_key.
const bool autotune_done = !!per_key->autotune.Best();
// Ensure the tiled view returns the same result as C.
if constexpr (IsBF16<TA>() && IsBF16<TC>()) {
// The total view area should match the entire C matrix.
std::atomic<size_t> total_view_area = 0;
const auto fused = [&](RowPtrsBF C2_rows, IndexRange range_r,
IndexRange range_c, StridedViewBF C2_view,
size_t worker) {
total_view_area.fetch_add(range_r.Num() * range_c.Num());
HWY_ASSERT(range_c.Num() <= C2_view.Cols());
HWY_ASSERT(worker < env.ctx.pools.MaxWorkers());
for (size_t ir = 0; ir < range_r.Num(); ++ir) {
const size_t r = range_r.begin() + ir;
for (size_t ic = 0; ic < range_c.Num(); ++ic) {
const size_t c = range_c.begin() + ic;
const float expected =
hwy::ConvertScalarTo<float>(C2_rows.Row(r)[c]);
const float actual =
hwy::ConvertScalarTo<float>(C2_view.Row(ir)[ic]);
const float L1 = hwy::ScalarAbs(actual - expected);
if (L1 > 1E-6f) {
HWY_ABORT("%zu: ir %zu ic %zu L1 %f expected %f actual %f.",
worker, ir, ic, L1, expected, actual);
}
}
}
};
options.SetFunc(fused);
TwoMatMulStatic(A, BT, BT, env, C2, options);
HWY_ASSERT_EQ(C.Extents().Area(), total_view_area.load());
options.func = nullptr; // reset for next call
// TwoMatMulStatic() does not support adding a bias vector.
if (!add) {
AssertClose(A, BT, C, C2, env, line);
}
}
if (autotune_done) break;
}
}

View File

@ -69,6 +69,14 @@ MMPerKey* CallMatMul(const MatPtrT<TA>& A, const MatPtr& B,
});
}
static inline void CallTwoMatMul(const MatPtrT<BF16>& A, const MatPtr& B1,
const MatPtr& B2, MatMulEnv& env,
MatPtrT<BF16>& C, const MMOptions& options) {
return CallUpcastedSame(&B1, &B2, [&](const auto* B1_t, const auto* B2_t) {
return TwoMatMulStatic(A, *B1_t, *B2_t, env, C, options);
});
}
HWY_INLINE double PackTokenAndProb(int32_t token, float prob) {
// casting prob from float to double just makes some changes to the
// exponent bias and pads zeros in the mantissa.

View File

@ -40,10 +40,11 @@ class RowPtrs {
public:
RowPtrs(uint8_t** row_ptrs) : row_ptrs_(row_ptrs), r0_(0), c0_(0) {}
RowPtrs View(size_t r, size_t c) {
// Extra argument is for compatibility with `StridedView`.
RowPtrs View(size_t r, size_t c, size_t /*cols*/) {
RowPtrs<T> view(row_ptrs_);
view.r0_ = static_cast<uint32_t>(r);
view.c0_ = static_cast<uint32_t>(c);
view.r0_ = static_cast<uint32_t>(r0_ + r);
view.c0_ = static_cast<uint32_t>(c0_ + c);
return view;
}
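// Offsets now accumulate, so nested views compose: e.g. View(2, 4, cols)
// followed by View(1, 1, cols) addresses row 3, column 5 of the original
// matrix, whereas previously the second View() reset the offsets.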
@ -531,7 +532,11 @@ class StridedView {
: row0_(row0),
cols_(static_cast<uint32_t>(cols)),
stride_(static_cast<uint32_t>(stride)) {
HWY_DASSERT(stride >= cols);
if constexpr (HWY_IS_DEBUG_BUILD) {
if (stride < cols) {
HWY_ABORT("stride %zu < cols %zu", stride, cols);
}
}
}
// Returns 2D subrange whose top-left is `r, c` and width is `cols`.