diff --git a/evals/gemma_batch_bench.cc b/evals/gemma_batch_bench.cc index 3ffa858..ff81671 100644 --- a/evals/gemma_batch_bench.cc +++ b/evals/gemma_batch_bench.cc @@ -37,7 +37,6 @@ class GemmaBatchBench : public ::testing::Test { protected: std::vector BatchGemmaReply( const std::vector& inputs) { - s_env->SetMaxGeneratedTokens(24); s_env->MutableConfig().temperature = 0.0f; // deterministic s_env->MutableConfig().verbosity = 2; std::vector replies; @@ -92,15 +91,18 @@ TEST_F(GemmaBatchBench, RandomQuestionsBatched) { inputs.push_back(questions[qpos++]); if (qpos == questions.size()) qpos = 0; } + s_env->SetMaxGeneratedTokens(24); std::vector responses = BatchGemmaReply(inputs); for (size_t i = 0; i < HWY_MIN(hwy::Unpredictable1() * 3, responses.size()); ++i) { fprintf(stderr, "Batch answer %zu '%s'\n\n", i, responses[i].c_str()); } + PROFILER_PRINT_RESULTS(); + // Run again: prefill will be faster due to autotuning. Fewer decode steps // because those are already fast. - s_env->SetMaxGeneratedTokens(3); + s_env->SetMaxGeneratedTokens(2); responses = BatchGemmaReply(inputs); PROFILER_PRINT_RESULTS(); diff --git a/gemma/configs.h b/gemma/configs.h index e02645b..275f374 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -36,6 +36,10 @@ HWY_INLINE_VAR constexpr int kAttentionUseOld = 2; HWY_INLINE_VAR constexpr size_t kMaxQKVDim = 1024; +#ifndef GEMMA_FUSED_FFN +#define GEMMA_FUSED_FFN 1 +#endif // !GEMMA_FUSED_FFN + // Instruction-tuned models require extra 'turn structure' tokens in prompts. enum class PromptWrapping { GEMMA_IT, diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index cb7ae6a..a7f1b01 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -43,6 +43,7 @@ HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { +// For use by Vit even if !GEMMA_FUSED_FFN. template void Activation(ActivationType activation, T1* HWY_RESTRICT c1, const T2* HWY_RESTRICT c2, const size_t count, hwy::Profiler& p, @@ -64,7 +65,7 @@ void Activation(ActivationType activation, T1* HWY_RESTRICT c1, }); } -// No C2 multiplier. +// No C2 multiplier - used by Vit. template void ActivationBatched( ActivationType activation, Mat& c1, ThreadingContext& ctx, @@ -80,6 +81,34 @@ void ActivationBatched( }); } +#if GEMMA_FUSED_FFN + +// Called during `TwoMatMul`. +static inline void Activation(ActivationType activation, const RowPtrsBF C1, + const IndexRange range_r, + const IndexRange range_c, const StridedViewBF C2, + hwy::Profiler& p, const size_t worker) { + static const auto zone = p.AddZone("Gen.ActivationFused"); + PROFILER_ZONE3(p, worker, zone); + + const size_t cols = range_c.Num(); + HWY_DASSERT(C2.Cols() == cols); + + namespace hn = hwy::HWY_NAMESPACE; + using DF = hn::ScalableTag; + using VF = hn::Vec; + // ActivationType::Gelu + // Gated: Gelu(c1) * c2. 
+ for (size_t ir = 0; ir < range_r.Num(); ++ir) { + Decompress1AndCompressInplace( + DF(), C1.Row(range_r.begin() + ir) + range_c.begin(), cols, C2.Row(ir), + [](DF df, VF v1, VF v2) + HWY_ATTR -> VF { return hn::Mul(v2, Gelu(df, v1)); }); + } +} + +#else + template HWY_NOINLINE void ActivationBatched( ActivationType activation, Mat1& c1, const Mat2* c2, ThreadingContext& ctx, @@ -102,6 +131,8 @@ HWY_NOINLINE void ActivationBatched( } } +#endif // GEMMA_FUSED_FFN + template HWY_NOINLINE void ResidualConnection(const MatPtrT& other, MatPtrT& HWY_RESTRICT x, @@ -126,28 +157,32 @@ static inline void FFWNoVit(const LayerWeightsPtrs& layer, env.ctx.profiler.AddZone("Gen.FFW", hwy::ProfilerFlags::kInclusive); PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone); const LayerConfig& layer_config = layer.layer_config; - const size_t ffh_hidden_dim = layer_config.ff_hidden_dim; - const bool add_bias = layer_config.ff_biases; - const float* bias1 = - add_bias ? layer.ffw_gating_biases.PackedScale1() : nullptr; - const float* bias2 = add_bias ? bias1 + ffh_hidden_dim : nullptr; - const float* output_bias = - add_bias ? layer.ffw_output_biases.PackedScale1() : nullptr; + HWY_DASSERT(!layer_config.ff_biases); // Only used in Vit. +#if GEMMA_FUSED_FFN + const auto fused = [&](RowPtrsBF C1, IndexRange range_r, IndexRange range_c, + StridedViewBF C2, size_t worker) { + Activation(layer_config.activation, C1, range_r, range_c, C2, + env.ctx.profiler, worker); + }; + MMOptions options; + options.SetFunc(fused); + CallTwoMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w1, + layer.gating_einsum_w2, env, activations.C1, options); +#else // Compute the hidden layer activations. - CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w1, bias1, env, + CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w1, nullptr, env, activations.C1); - CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w2, bias2, env, + CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w2, nullptr, env, activations.C2); - // Activation (Gelu) and maybe multiply by gate. Store activations in act. ActivationBatched(layer_config.activation, activations.C1, &activations.C2, env.ctx); +#endif // Hidden layer -> output layer. - CallMatMul(activations.C1, layer.linear_w, output_bias, env, - activations.ffw_out); + CallMatMul(activations.C1, layer.linear_w, nullptr, env, activations.ffw_out); } // NOLINTNEXTLINE(google-readability-namespace-comments) diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 3a20690..8957f4c 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -155,14 +155,14 @@ class MMStoreHorizontalSumsIntoC { template , class Tag, class CView> HWY_INLINE void Store(D4 d4, V4 sum0, V4 sum1, V4 sum2, V4 sum3, const float scale, const float* HWY_RESTRICT add, - const size_t imc, Tag tag, CView C_rows) const { + const size_t imc, Tag tag, CView C_MC_NR) const { const V4 vscale = hn::Set(d4, scale); HWY_ALIGN static constexpr float kZero[4] = {}; const V4 vadd = hn::Load(d4, add ? 
add : kZero); - MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, tag, imc, C_rows); - MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, tag, imc, C_rows); - MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, tag, imc, C_rows); - MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, tag, imc, C_rows); + MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, tag, imc, C_MC_NR); + MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, tag, imc, C_MC_NR); + MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, tag, imc, C_MC_NR); + MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, tag, imc, C_MC_NR); } private: @@ -202,10 +202,10 @@ class MMStoreHorizontalSumsIntoC { class Tag, class CView> static HWY_INLINE void MaybeScaleAndStore(DF4 df4, VF4 sum, VF4 vscale, VF4 vadd, Tag, const size_t imc, - CView C_view) { + CView C_MC_NR) { if constexpr (kRow < kRowsAC) { - using TC = hwy::RemoveCvRef; - TC* HWY_RESTRICT pos = C_view.Row(imc + kRow); + using TC = hwy::RemoveCvRef; + TC* HWY_RESTRICT pos = C_MC_NR.Row(imc + kRow); const hn::Rebind dc4; if constexpr (hwy::IsSame()) { vadd = F32FromTC(dc4, hn::Load(dc4, pos)); // load prior value @@ -268,9 +268,9 @@ class MMDecompress { } else { // Always decompress. To reduce code size/compile time, we no longer // support a separate F32 kernel; most A are already BF16. We also only - // have a single MMStorage. + // have a single MMEntireA. HWY_ASSERT(options.cluster_idx == 0); - const StridedViewBF A_view = env.storage.A(A.Extents()); + const StridedViewBF A_view = env.A_BF.A(A.Extents()); AutotuneDecompressA(A, A_view, autotune, env, options); return A_view; } @@ -387,111 +387,52 @@ class MMDecompress { // Stateless, wraps member functions. Contains the innermost 2-4 loops. class MMKernel { - // Compute size of per-worker storage for `kNR` row ranges of B. Stack - // allocation avoids passing a worker index. - static constexpr size_t B_stride_max = - kMaxKC + 2 * CacheInfo::MaxLineBytes() / sizeof(BF16); - public: - // A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output. - // Calls `LoopKC` for each of `mc` rows of A in steps of `mr`. `A_view` is - // `mc x kc` and `B_view` is `(kNR x kc)`. All views, including `add`, start - // at row/col 0. `CView` is either `RowPtrs` or `StridedView`. - // Called by B3A2C0 and by callers that hoist `A_view`. 
- template - static HWY_INLINE void A2C0(const StridedViewBF A_view, - const StridedViewBF B_view, size_t mr, - const IndexRange& range_mc, size_t kc, - const float scale, const float* HWY_RESTRICT add, - Tag tag, CView C_view) { - HWY_DASSERT(1 <= mr && mr <= kMaxMR); - - const size_t mc = range_mc.Num(); - size_t imc = 0; - - // M == 1, or x86 with 8 SIMD registers: - if (HWY_UNLIKELY(mr == 1)) { - for (; imc < mc; ++imc) { - LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_view); - } - return; - } - - // AVX2 (16 registers) - if (HWY_UNLIKELY(mr == 2)) { - if (HWY_LIKELY(mc >= 2)) { - for (; imc <= mc - 2; imc += 2) { - LoopKC<2>(A_view, B_view, imc, kc, scale, add, tag, C_view); - } - } - if (HWY_UNLIKELY(imc != mc)) { - LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_view); - } - return; - } - - HWY_DASSERT(mr == 4); - if (HWY_LIKELY(mc >= 4)) { - for (; imc <= mc - 4; imc += 4) { - LoopKC<4>(A_view, B_view, imc, kc, scale, add, tag, C_view); - } - } - const size_t remainder_mc = mc - imc; - HWY_DASSERT(remainder_mc < 4); - if (HWY_UNLIKELY(remainder_mc & 2)) { - LoopKC<2>(A_view, B_view, imc, kc, scale, add, tag, C_view); - imc += 2; - } - if (HWY_UNLIKELY(remainder_mc & 1)) { - LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_view); - imc += 1; - } - HWY_DASSERT(imc == mc); - } - - static constexpr size_t B_storage_max = kNR * B_stride_max; - // Loop over NC/MC/KC, called from the outer loops. The MOMMS B3A2C0 reads - // `mc x kc` of A, `nc x kc` of B, and updates `mc x nc` of C. Called by - // `ForeachKC` and when there is only a single KC task. - template + // `mc x kc` of A, `nc x kc` of B, and updates the `mc x nc` `C_MC_NC`. + // `CView` is either `RowPtrs` or `StridedView`. + template static void B3A2C0(const StridedViewBF A, const MatPtrT& B, const IndexRange& range_mc, const IndexRange& range_kc, const IndexRange& range_nc, const MMArgs& args, - Tag out_tag, RowPtrs C) { - HWY_ALIGN BF16 B_storage[B_storage_max]; - + Tag out_tag, CView C_MC_NC) { const size_t kc = range_kc.Num(); const StridedViewBF A_view = A.View(range_mc.begin(), range_kc.begin(), kc); + // Upper bound on per-worker storage for `kNR` row ranges of B. Stack + // allocation avoids passing a worker index. + constexpr size_t B_stride_max = + kMaxKC + 2 * CacheInfo::MaxLineBytes() / sizeof(BF16); + HWY_ALIGN BF16 B_storage[kNR * B_stride_max]; const size_t B_stride = Stride(MatPadding::kOdd, kc, sizeof(BF16), args.line_bytes); const StridedViewBF B_storage_view(B_storage, kc, B_stride); - for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); - row_b += kNR) { + const float scale = args.scale_A * B.Scale(); + for (size_t inc = 0; inc < range_nc.Num(); inc += kNR) { + // For `add` and `B`, which are global, unlike `C_MC_NC`. + const size_t row_b = range_nc.begin() + inc; const StridedViewBF B_view = MMDecompress::DecompressB(B, row_b, range_kc, B_storage_view); - const RowPtrs C_view = C.View(range_mc.begin(), row_b); + const CView C_MC_NR = C_MC_NC.View(0, inc, kNR); const float* HWY_RESTRICT add = args.add ? 
args.add + row_b : nullptr; - A2C0(A_view, B_view, args.mr, range_mc, kc, args.scale, add, out_tag, - C_view); + A2C0(A_view, B_view, args.mr, range_mc, kc, scale, add, out_tag, C_MC_NR); } } - template + template static void ForeachKC(const StridedViewBF A, const MatPtrT& B, const IndexRange& range_mc, const IndexRangePartition& ranges_kc, const IndexRange& range_nc, const MMArgs& args, - RowPtrs C) { + CView C_MC_NC) { // Peel off the first iteration of the kc loop: avoid zero-initializing `C` // by writing directly into it, and later accumulating into it. ranges_kc.VisitFirst([&](const IndexRange& range_kc) { - B3A2C0(A, B, range_mc, range_kc, range_nc, args, MMSetC(), C); + B3A2C0(A, B, range_mc, range_kc, range_nc, args, MMSetC(), C_MC_NC); }); ranges_kc.VisitRemaining([&](const IndexRange& range_kc) { - B3A2C0(A, B, range_mc, range_kc, range_nc, args, MMAddC(), C); + B3A2C0(A, B, range_mc, range_kc, range_nc, args, MMAddC(), C_MC_NC); }); } @@ -585,15 +526,15 @@ class MMKernel { // Innermost loop over `kc` columns (typically 1024-4096, not necessarily a // multiple of `NBF`) in steps of one vector, for `kRowsAC` rows of `A_view` // from range_mc-relative `imc` and `B_view` from row 0 (both at column 0). - // Updates a `kRowsAC x kNR` tile in `C_view` starting at row `imc`, column 0. - // `A` and `B` are always BF16, `C` can be F32 or BF16. `add` is also + // Updates a `kRowsAC x kNR` tile in `C_MC_NR` starting at row `imc`, column + // 0. `A` and `B` are always BF16, `C` can be F32 or BF16. `add` is also // relative to the C column. template static HWY_INLINE void LoopKC(const StridedViewBF A_view, const StridedViewBF B_view, size_t imc, size_t kc, const float scale, const float* HWY_RESTRICT add, Tag tag, - CView C_view) { + CView C_MC_NR) { const hn::ScalableTag dbf; using VBF = hn::Vec; HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf); @@ -777,7 +718,62 @@ class MMKernel { hn::Vec sum0, sum1, sum2, sum3; horz.Reduce4x4(df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30, C31, C32, C33, sum0, sum1, sum2, sum3); - horz.Store(d4, sum0, sum1, sum2, sum3, scale, add, imc, tag, C_view); + horz.Store(d4, sum0, sum1, sum2, sum3, scale, add, imc, tag, C_MC_NR); + } + + // A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output. + // Calls `LoopKC` for each of `mc` rows of A in steps of `mr`. `A_view` is + // `mc x kc` and `B_view` is `(kNR x kc)`. All views, including `add`, start + // at row/col 0. 
+ template + static HWY_INLINE void A2C0(const StridedViewBF A_view, + const StridedViewBF B_view, size_t mr, + const IndexRange& range_mc, size_t kc, + const float scale, const float* HWY_RESTRICT add, + Tag tag, CView C_MC_NR) { + HWY_DASSERT(1 <= mr && mr <= kMaxMR); + + const size_t mc = range_mc.Num(); + size_t imc = 0; + + // M == 1, or x86 with 8 SIMD registers: + if (HWY_UNLIKELY(mr == 1)) { + for (; imc < mc; ++imc) { + LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_MC_NR); + } + return; + } + + // AVX2 (16 registers) + if (HWY_UNLIKELY(mr == 2)) { + if (HWY_LIKELY(mc >= 2)) { + for (; imc <= mc - 2; imc += 2) { + LoopKC<2>(A_view, B_view, imc, kc, scale, add, tag, C_MC_NR); + } + } + if (HWY_UNLIKELY(imc != mc)) { + LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_MC_NR); + } + return; + } + + HWY_DASSERT(mr == 4); + if (HWY_LIKELY(mc >= 4)) { + for (; imc <= mc - 4; imc += 4) { + LoopKC<4>(A_view, B_view, imc, kc, scale, add, tag, C_MC_NR); + } + } + const size_t remainder_mc = mc - imc; + HWY_DASSERT(remainder_mc < 4); + if (HWY_UNLIKELY(remainder_mc & 2)) { + LoopKC<2>(A_view, B_view, imc, kc, scale, add, tag, C_MC_NR); + imc += 2; + } + if (HWY_UNLIKELY(remainder_mc & 1)) { + LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_MC_NR); + imc += 1; + } + HWY_DASSERT(imc == mc); } }; @@ -813,10 +809,10 @@ class MMImpl { } public: - static MMPerKey& FindOrAddPerKey(size_t M, size_t K, size_t N, + static MMPerKey& FindOrAddPerKey(size_t M, size_t K, size_t N, size_t num_B, size_t vector_bytes, MatMulEnv::PerCluster& per_cluster) { - const MMKeys::Key key = MMKeys::KeyFromDims(M, K, N); + const MMKeys::Key key = MMKeys::KeyFromDims(M, K, N, num_B); intptr_t index = IndexOfKey(key, per_cluster.keys); // First time we see this shape/key. if (HWY_UNLIKELY(index < 0)) { @@ -831,17 +827,19 @@ class MMImpl { } static void NotifyAutotuneResult(MatMulEnv& env, size_t M, size_t K, size_t N, - double t0, MMAutoTune& tuner, + size_t num_B, double t0, + MMAutoTune& tuner, const MMConfig& cfg) { const uint64_t t1 = env.have_timer_stop ? 
hwy::timer::Stop() : hwy::timer::Start(); const double min_elapsed = static_cast(tuner.NotifyTicks(t1 - t0)) / hwy::platform::InvariantTicksPerSecond(); - const double flops = 2 * M * K * N / min_elapsed; // * 2 for FMA + const double flops = 2 * M * K * N * num_B / min_elapsed; // * 2 for FMA if (HWY_UNLIKELY(env.print_measurement && tuner.ShouldPrint())) { - fprintf(stderr, "%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu\n", flops * 1E-9, - min_elapsed * 1E3, cfg.MR(), cfg.MC(), cfg.KC(), cfg.NC(), - StringFromOrder(cfg.Order()), cfg.InnerTasks()); + fprintf(stderr, "%zu,%zu,%zu,%zu,%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu\n", + M, K, N, num_B, flops * 1E-9, min_elapsed * 1E3, cfg.MR(), + cfg.MC(), cfg.KC(), cfg.NC(), StringFromOrder(cfg.Order()), + cfg.InnerTasks()); } if (HWY_UNLIKELY(env.print_best && tuner.Best())) { const auto ratio = [&tuner](uint64_t ticks) -> double { @@ -849,12 +847,13 @@ class MMImpl { static_cast(tuner.BestTicks()); }; const MMConfig& best = *tuner.Best(); - fprintf(stderr, - "\n%zu,%zu,%zu,%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu,%.2f,%.2f\n", - M, K, N, flops * 1E-9, min_elapsed * 1E3, best.MR(), best.MC(), - best.KC(), best.NC(), StringFromOrder(best.Order()), - best.InnerTasks(), ratio(tuner.WorstMinTicks()), - ratio(tuner.FirstConfigTicks())); + fprintf( + stderr, + "\n%zu,%zu,%zu,%zu,%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu,%.2f,%.2f\n", + M, K, N, num_B, flops * 1E-9, min_elapsed * 1E3, best.MR(), best.MC(), + best.KC(), best.NC(), StringFromOrder(best.Order()), + best.InnerTasks(), ratio(tuner.WorstMinTicks()), + ratio(tuner.FirstConfigTicks())); } } @@ -874,10 +873,11 @@ class MMImpl { class MMLoops { public: // Called from `MatMul` from two places: either with the next autotune config, - // or with the best config. + // or with the best config. `B2` is null unless called from `TwoMatMul`. template static HWY_NOINLINE void Dispatch(const StridedViewBF A, const MatPtrT& B, - RowPtrs C, const MMArgs& args) { + const MatPtrT* B2, RowPtrs C, + const MMArgs& args) { static const auto zone = args.env.ctx.profiler.AddZone("MM.Dispatch"); PROFILER_ZONE3(args.env.ctx.profiler, args.env.ctx.Worker(args.options.cluster_idx), zone); @@ -885,7 +885,7 @@ class MMLoops { DispatchParallelism( args.options.parallelism, [&](const auto& parallel) HWY_ATTR { DispatchOrder(args.order, [&](const auto& order) HWY_ATTR { - Loop(order, parallel, A, B, C, args); + Loop(order, parallel, A, B, B2, C, args); }); }); } @@ -901,18 +901,14 @@ class MMLoops { template static HWY_INLINE void Loop(MMOrderNT, Parallel parallel, const StridedViewBF A, const MatPtrT& B, - RowPtrs C, const MMArgs& args) { + const MatPtrT* B2, RowPtrs C, + const MMArgs& args) { static const auto zone = args.env.ctx.profiler.AddZone("MM.NT"); HWY_DASSERT(args.ranges_mc.NumTasks() == 1); HWY_DASSERT(args.ranges_kc.NumTasks() == 1); - const IndexRange& range_M = args.ranges_mc.Range(0); - const IndexRange& range_K = args.ranges_kc.Range(0); - const size_t K = range_K.Num(); - const StridedViewBF A_view = A.View(range_M.begin(), 0, K); - const size_t B_stride = - Stride(MatPadding::kOdd, K, sizeof(BF16), args.line_bytes); + const IndexRange& range_mc = args.ranges_mc.Range(0); + const IndexRange& range_kc = args.ranges_kc.Range(0); - // Similar to `B3A2C0`, but here we hoisted `A_view`. 
parallel.ForN( args.env.ctx, args.range_n, MultipleN(sizeof(TC), args.line_bytes), args.inner_tasks, args.options.cluster_idx, @@ -920,26 +916,19 @@ class MMLoops { MMZone mm_zone; mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune); - HWY_ALIGN BF16 B_storage[MMKernel::B_storage_max]; // TLS - const StridedViewBF B_storage_view(B_storage, K, B_stride); + MMKernel::B3A2C0(A, B, range_mc, range_kc, range_nc, args, MMSetC(), + C.View(0, range_nc.begin(), range_nc.Num())); - for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); - row_b += kNR) { - const StridedViewBF B_view = - MMDecompress::DecompressB(B, row_b, range_K, B_storage_view); - const RowPtrs C_view = C.View(range_M.begin(), row_b); - const float* HWY_RESTRICT add = - args.add ? args.add + row_b : nullptr; + const StridedViewBF C2 = args.env.C_tiles.C( + Extents2D(range_mc.Num(), range_nc.Num()), worker); - MMKernel::A2C0(A_view, B_view, args.mr, range_M, K, args.scale, add, - MMSetC(), C_view); + if (B2 != nullptr) { + MMKernel::B3A2C0(A, *B2, range_mc, range_kc, range_nc, args, + MMSetC(), C2); } if constexpr (IsBF16()) { - if (args.options.fused) { - StridedViewBF C2(nullptr, 0, 0); - args.options.fused(C, range_M, range_nc, C2, worker); - } + args.options.MaybeCallFunc(C, range_mc, range_nc, C2, worker); } }); } @@ -948,7 +937,8 @@ class MMLoops { template static HWY_INLINE void Loop(MMOrderNT_K, Parallel parallel, const StridedViewBF A, const MatPtrT& B, - RowPtrs C, const MMArgs& args) { + const MatPtrT* B2, RowPtrs C, + const MMArgs& args) { static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_K"); HWY_DASSERT(args.ranges_mc.NumTasks() == 1); const IndexRange& range_mc = args.ranges_mc.Range(0); @@ -959,14 +949,21 @@ class MMLoops { [&](const IndexRange& range_nc, size_t worker) HWY_ATTR { MMZone mm_zone; mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune); - MMKernel::ForeachKC(A, B, range_mc, args.ranges_kc, - range_nc, args, C); + MMKernel::ForeachKC( + A, B, range_mc, args.ranges_kc, range_nc, args, + C.View(0, range_nc.begin(), range_nc.Num())); + + const StridedViewBF C2 = args.env.C_tiles.C( + Extents2D(range_mc.Num(), range_nc.Num()), worker); + + if (B2 != nullptr) { + MMKernel::ForeachKC(A, *B2, range_mc, args.ranges_kc, + range_nc, args, C2); + } if constexpr (IsBF16()) { - if (args.options.fused) { - StridedViewBF C2(nullptr, 0, 0); - args.options.fused(C, range_mc, range_nc, C2, worker); - } + args.options.MaybeCallFunc(C, range_mc, range_nc, C2, + worker); } }); } @@ -976,10 +973,11 @@ class MMLoops { template static HWY_INLINE void Loop(MMOrderNT_MT, Parallel parallel, const StridedViewBF A, const MatPtrT& B, - RowPtrs C, const MMArgs& args) { + const MatPtrT* B2, RowPtrs C, + const MMArgs& args) { static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_MT"); HWY_DASSERT(args.ranges_kc.NumTasks() == 1); - const IndexRange& range_K = args.ranges_kc.Range(0); + const IndexRange& range_kc = args.ranges_kc.Range(0); parallel.ForRangesMC_NC( args.env.ctx, args.ranges_mc, args.ranges_nc, args.options.cluster_idx, @@ -987,14 +985,19 @@ class MMLoops { size_t worker) HWY_ATTR { MMZone mm_zone; mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune); - MMKernel::B3A2C0(A, B, range_mc, range_K, range_nc, args, MMSetC(), - C); + MMKernel::B3A2C0( + A, B, range_mc, range_kc, range_nc, args, MMSetC(), + C.View(range_mc.begin(), range_nc.begin(), range_nc.Num())); + const StridedViewBF C2 = args.env.C_tiles.C( + Extents2D(range_mc.Num(), range_nc.Num()), worker); + + if (B2 != 
nullptr) { + MMKernel::B3A2C0(A, *B2, range_mc, range_kc, range_nc, args, + MMSetC(), C2); + } if constexpr (IsBF16()) { - if (args.options.fused) { - StridedViewBF C2(nullptr, 0, 0); - args.options.fused(C, range_mc, range_nc, C2, worker); - } + args.options.MaybeCallFunc(C, range_mc, range_nc, C2, worker); } }); } @@ -1004,7 +1007,8 @@ class MMLoops { template static HWY_INLINE void Loop(MMOrderNT_MT_K, Parallel parallel, const StridedViewBF A, const MatPtrT& B, - RowPtrs C, const MMArgs& args) { + const MatPtrT* B2, RowPtrs C, + const MMArgs& args) { static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_MT_K"); parallel.ForRangesMC_NC( @@ -1013,14 +1017,20 @@ class MMLoops { size_t worker) HWY_ATTR { MMZone mm_zone; mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune); - MMKernel::ForeachKC(A, B, range_mc, args.ranges_kc, range_nc, args, - C); + MMKernel::ForeachKC( + A, B, range_mc, args.ranges_kc, range_nc, args, + C.View(range_mc.begin(), range_nc.begin(), range_nc.Num())); + + const StridedViewBF C2 = args.env.C_tiles.C( + Extents2D(range_mc.Num(), range_nc.Num()), worker); + + if (B2 != nullptr) { + MMKernel::ForeachKC(A, *B2, range_mc, args.ranges_kc, range_nc, + args, C2); + } if constexpr (IsBF16()) { - if (args.options.fused) { - StridedViewBF C2(nullptr, 0, 0); - args.options.fused(C, range_mc, range_nc, C2, worker); - } + args.options.MaybeCallFunc(C, range_mc, range_nc, C2, worker); } }); } @@ -1060,20 +1070,23 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, const size_t M = A.Rows(); const size_t K = A.Cols(); const size_t N = B.Rows(); + const size_t num_B = 1; const CacheInfo& cache = env.ctx.cache_info; - MMPerKey& per_key = MMImpl::FindOrAddPerKey(M, K, N, cache.VectorBytes(), - env.per_cluster[cluster_idx]); + MMPerKey& per_key = MMImpl::FindOrAddPerKey( + M, K, N, num_B, cache.VectorBytes(), env.per_cluster[cluster_idx]); // (Also auto-tunes, hence outside the timed section to prevent interference.) const StridedViewBF A_view = MMDecompress::MaybeDecompressA(A, per_key.autotune_par_a, env, options); + MatPtrT* B2 = nullptr; // required for type matching + MMAutoTune& tuner = per_key.autotune; if (HWY_LIKELY(tuner.Best())) { - const MMArgs args(env, M, K, N, static_cast(A.Scale()) * B.Scale(), - add, options, tuner, *tuner.Best()); - MMLoops::Dispatch(A_view, B, C_rows, args); + const MMArgs args(env, M, K, N, A.Scale(), add, options, tuner, + *tuner.Best()); + MMLoops::Dispatch(A_view, B, B2, C_rows, args); return &per_key; } @@ -1082,20 +1095,83 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, // Ensure matrix dimensions match each other (off the hot path). 
HWY_ASSERT(K == B.Cols()); HWY_ASSERT(M <= kMaxBatchSize); - HWY_ASSERT(K <= MMStorage::kMaxK); + HWY_ASSERT(K <= MMEntireA::kMaxK); HWY_ASSERT(N % kNR == 0); MMImpl::EnsureAligned(A, cache.VectorBytes()); tuner.SetCandidates( - MMCandidates(cache, M, K, N, sizeof(TC), env.print_config)); + MMCandidates(cache, M, K, N, num_B, sizeof(TC), env.print_config)); } const MMConfig& cfg = tuner.NextConfig(); - const MMArgs args(env, M, K, N, static_cast(A.Scale()) * B.Scale(), - add, options, tuner, cfg); + const MMArgs args(env, M, K, N, A.Scale(), add, options, tuner, cfg); const uint64_t t0 = hwy::timer::Start(); - MMLoops::Dispatch(A_view, B, C_rows, args); - MMImpl::NotifyAutotuneResult(env, M, K, N, t0, tuner, cfg); + MMLoops::Dispatch(A_view, B, B2, C_rows, args); + MMImpl::NotifyAutotuneResult(env, M, K, N, num_B, t0, tuner, cfg); + + return &per_key; +} + +// Performs A*B1 and A*B2 in parallel. This is useful for gated FFNs. +// Differences vs MatMul: The second result matrix is not materialized, it is +// only passed to the `options.func` callback. There is no `add` argument +// because it is not required for this use case. There is no default `options` +// argument because `options.func` must be set by the caller. +template +HWY_NOINLINE MMPerKey* TwoMatMul(const MatPtrT& A, const MatPtrT& B1, + const MatPtrT& B2, MatMulEnv& env, + MatPtrT& C, MMOptions options) { + static const auto zone = env.ctx.profiler.AddZone("MM.TwoMatMul"); + const size_t cluster_idx = options.cluster_idx; + HWY_DASSERT(cluster_idx < env.row_ptrs.size()); + PROFILER_ZONE3(env.ctx.profiler, env.ctx.Worker(cluster_idx), zone); + + HWY_DASSERT(options.func != nullptr); // no other way to get access to C2. + + RowPtrs C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[cluster_idx]); + + const size_t M = A.Rows(); + const size_t K = A.Cols(); + const size_t N = B1.Rows(); + const size_t num_B = 2; + + const CacheInfo& cache = env.ctx.cache_info; + MMPerKey& per_key = MMImpl::FindOrAddPerKey( + M, K, N, num_B, cache.VectorBytes(), env.per_cluster[cluster_idx]); + + // (Also auto-tunes, hence outside the timed section to prevent interference.) + const StridedViewBF A_view(A, 0, 0, A.Cols()); + + MMAutoTune& tuner = per_key.autotune; + if (HWY_LIKELY(tuner.Best())) { + // Only A scale - B1/B2 may differ, and are passed separately. + const MMArgs args(env, M, K, N, A.Scale(), + /*add=*/nullptr, options, tuner, *tuner.Best()); + MMLoops::Dispatch(A_view, B1, &B2, C_rows, args); + return &per_key; + } + + // Autotuning, first call: enumerate all feasible configs. + if (HWY_UNLIKELY(!tuner.HasCandidates())) { + // Ensure matrix dimensions match each other (off the hot path). + HWY_ASSERT(K == B1.Cols()); + HWY_ASSERT(K == B2.Cols()); + HWY_ASSERT(M <= kMaxBatchSize); + HWY_ASSERT(K <= MMEntireA::kMaxK); + HWY_ASSERT(N % kNR == 0); + MMImpl::EnsureAligned(A, cache.VectorBytes()); + tuner.SetCandidates( + MMCandidates(cache, M, K, N, num_B, sizeof(BF16), env.print_config)); + } + + const MMConfig& cfg = tuner.NextConfig(); + // Only A scale - B1/B2 may differ, and are passed separately. 
+ const MMArgs args(env, M, K, N, A.Scale(), /*add=*/nullptr, options, tuner, + cfg); + + const uint64_t t0 = hwy::timer::Start(); + MMLoops::Dispatch(A_view, B1, &B2, C_rows, args); + MMImpl::NotifyAutotuneResult(env, M, K, N, num_B, t0, tuner, cfg); return &per_key; } diff --git a/ops/matmul.cc b/ops/matmul.cc index 66ce0df..6ef1412 100644 --- a/ops/matmul.cc +++ b/ops/matmul.cc @@ -63,11 +63,12 @@ size_t PrevDivisor(const size_t begin, const size_t end, const size_t dim, class GenerateCandidates { public: GenerateCandidates(const CacheInfo& cache, size_t M, size_t K, size_t N, - size_t sizeof_TC, bool print_config) + size_t num_B, size_t sizeof_TC, bool print_config) : cache_(cache), M_(M), K_(K), N_(N), + num_B_(num_B), sizeof_TC_(sizeof_TC), // These influence kc/nc, but are also stored in `MMConfig` for // `RangesOf*`. Must be a vector multiple. The previous/next cache line @@ -150,7 +151,7 @@ class GenerateCandidates { } } - // The number of A and B columns to read between updating `partial`. + // The number of A and B columns to read between updating `C`. SizeVec KC(size_t mr, MMOrder order) const { // `LoopKC` handles up to `mr` rows of A. const size_t rows_a = HWY_MIN(M_, mr); @@ -164,7 +165,7 @@ class GenerateCandidates { // TB=NUQ due to less amortization of the table loads. Due to the low L1 // latency, the packing is still effectively fused into `LoopKC`. It may // be better to round up and accept a few L2 accesses in exchange for - // fewer loops over K, and thus fewer writes to `partial`. Hence we do not + // fewer loops over K, and thus fewer writes to `C`. Hence we do not // subtract the output and buf, and allow using more than the actual L1 // size. This results in an overestimate, and the loop below will propose // the next few smaller values for the autotuner to evaluate. @@ -179,7 +180,7 @@ class GenerateCandidates { // Avoid proposing kc > K. if (K_ > kc_multiple_) { - // Generally it is best to use the full `kc` (fewer writes to `partial`), + // Generally it is best to use the full `kc` (fewer writes to `C`), // but a bit less can be better if it evenly divides `K`, or enables an // `mc` that evenly divides `M`. Try several smaller values. @@ -196,7 +197,7 @@ class GenerateCandidates { } if (print_config_ && all_kc.size() > 1) { - fprintf(stderr, "KC: "); + fprintf(stderr, "num_B %zu: KC: ", num_B_); for (size_t kc : all_kc) { fprintf(stderr, "%zu ", kc); } @@ -214,18 +215,18 @@ class GenerateCandidates { // Choose the largest feasible `mc_max` (A/C rows) to maximize reuse of the // packed B. We want `mc * kc` elements of A to fit in L2, alongside - // `bytes_b` plus `mc` cache lines because resident-A updates `mc` rows of - // partial. + // `bytes_b` plus `mc` cache lines because resident-A updates `mc` C rows. const size_t bytes_per_mc = kc * sizeof(BF16) + cache_.LineBytes(); size_t mc_max = hwy::DivCeil(cache_.L2Bytes() - bytes_b, bytes_per_mc); - mc_max = HWY_MIN(mc_max, kMaxBatchSize); + mc_max = HWY_MIN(mc_max, HWY_MIN(kMaxBatchSize, kMaxMC)); HWY_DASSERT(mc_max != 0); mc_max = HWY_MIN(mc_max, M_); mc_max = hwy::RoundDownTo(mc_max, mr); SizeVec all_mc(1, mc_max); - // Larger MC is better for non-blocks, otherwise we want more small options. - const size_t reps = !IsBlock(order) ? 2 : 3; + // Larger MC is better for non-blocks, otherwise we want more small options, + // especially for two B. + const size_t reps = !IsBlock(order) ? 
2 : (2 + num_B_); size_t prev = mc_max; for (size_t rep = 0; rep < reps; ++rep) { @@ -240,7 +241,7 @@ class GenerateCandidates { } if (print_config_ && all_mc.size() > 1) { - fprintf(stderr, "MC: "); + fprintf(stderr, "num_B %zu: MC: ", num_B_); for (size_t mc : all_mc) { fprintf(stderr, "%zu ", mc); } @@ -252,14 +253,15 @@ class GenerateCandidates { // The number of (possibly L3 resident) B rows per `NT_MT` task. SizeVec NC(size_t mr, size_t mc, size_t kc, MMOrder order) const { - size_t nc_max = N_; + size_t nc_max = kMaxNC; // Only if there will be reuse of B: choose the largest `nc_max` (C cols) - // such that `nc x kc` of B and `mc x nc` of `partial` or `C` fit in L3. - // Otherwise, leave it unbounded. + // such that `nc x kc` of B and `mc x nc` of `C` fit in L3. Otherwise, + // leave it unbounded. if (M_ > mr) { const size_t bytes_per_nc = (kc * sizeof(BF16) + mc * sizeof_TC_); - nc_max = HWY_MIN(hwy::DivCeil(cache_.L3Bytes(), bytes_per_nc), N_); + nc_max = HWY_MIN(hwy::DivCeil(cache_.L3Bytes(), bytes_per_nc), kMaxNC); } + nc_max = HWY_MIN(nc_max, N_); HWY_DASSERT(nc_max != 0); nc_max = RoundDownWithFloor(nc_max, nc_multiple_); @@ -278,7 +280,7 @@ class GenerateCandidates { if (N_ > nc_multiple_) { // Large L3, but its behavior and characteristics varies across platforms, // hence autotune a wider range of nc than the other dimensions. - size_t reps = 10; + size_t reps = 9 + num_B_; // For small M, we can afford larger NC, hence allow fewer small options. if (M_ <= 2 * mr) reps -= 1; @@ -301,7 +303,7 @@ class GenerateCandidates { } if (print_config_ && all_nc.size() > 1) { - fprintf(stderr, "NC: "); + fprintf(stderr, "num_B %zu: NC: ", num_B_); for (size_t nc : all_nc) { fprintf(stderr, "%zu ", nc); } @@ -329,6 +331,7 @@ class GenerateCandidates { const size_t M_; const size_t K_; const size_t N_; + const size_t num_B_; const size_t sizeof_TC_; const size_t kc_multiple_; @@ -341,12 +344,13 @@ class GenerateCandidates { // Facade to avoid exposing `GenerateCandidates` in the header. std::vector MMCandidates(const CacheInfo& cache, size_t M, size_t K, - size_t N, size_t sizeof_TC, + size_t N, size_t num_B, size_t sizeof_TC, bool print_config) { - return GenerateCandidates(cache, M, K, N, sizeof_TC, print_config)(); + return GenerateCandidates(cache, M, K, N, num_B, sizeof_TC, print_config)(); } -MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx), storage(ctx.allocator) { +MatMulEnv::MatMulEnv(ThreadingContext& ctx) + : ctx(ctx), A_BF(ctx.allocator), C_tiles(ctx) { const size_t num_clusters = ctx.pools.AllClusters(/*pkg_idx=*/0).NumWorkers(); per_cluster.resize(num_clusters); for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) { diff --git a/ops/matmul.h b/ops/matmul.h index 93e7b04..bedee3d 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -21,7 +21,6 @@ #include #include -#include #include // IWYU pragma: begin_exports @@ -54,7 +53,9 @@ HWY_INLINE_VAR constexpr size_t kNR = 4; // or less on ISAs with fewer registers, or for the last few rows of A. HWY_INLINE_VAR constexpr size_t kMaxMR = 4; -HWY_INLINE_VAR constexpr size_t kMaxNC = 16384; // TODO: shrink? +// For `MMTilesC`. +HWY_INLINE_VAR constexpr size_t kMaxMC = 512; +HWY_INLINE_VAR constexpr size_t kMaxNC = 16384; // Upper bound for per-worker B storage on the stack. Chosen such that one row // of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`. 
@@ -108,9 +109,9 @@ struct MMParallelWithinCluster { hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); const size_t base = ctx.Worker(cluster_idx); - const IndexRangePartition worker_ranges = StaticPartition( + const IndexRangePartition ranges_n = StaticPartition( range_n, cluster.NumWorkers() * inner_tasks, n_multiple); - ParallelizeOneRange(worker_ranges, cluster, + ParallelizeOneRange(ranges_n, cluster, [&](const IndexRange& worker_range, size_t worker) { func(worker_range, base + worker); }); @@ -169,20 +170,20 @@ struct MMParallelHierarchical { if (num_clusters == 1) { const size_t cluster_idx = 0; hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); - const IndexRangePartition worker_ranges = StaticPartition( + const IndexRangePartition ranges_n = StaticPartition( range_n, cluster.NumWorkers() * inner_tasks, n_multiple); return ParallelizeOneRange( - worker_ranges, cluster, + ranges_n, cluster, [&](const IndexRange& worker_range, size_t worker) { func(worker_range, worker); }); } // Assign each cluster a sub-range of `range_n` (typically hundreds). - const IndexRangePartition n_ranges = + const IndexRangePartition ranges_n = StaticPartition(range_n, num_clusters, n_multiple); ParallelizeOneRange( - n_ranges, all_clusters, + ranges_n, all_clusters, [&](const IndexRange& n_range, const size_t cluster_idx) { hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); const size_t cluster_base = ctx.Worker(cluster_idx); @@ -274,32 +275,51 @@ void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC); // C is BF16/float. void BindC(ThreadingContext& ctx, MatPtr& C); -// For A. -class MMStorage { +// Space for converting A=F32 to BF16 before the matmul. This is faster than +// on-the-fly when native BF16 is available: it only happens once, not per B +// tile row, and the cache footprint is smaller. +class MMEntireA { public: // Compile-time bounds on matrix columns to enable pre-allocating storage // and reusing it across `MatMul` calls. Sufficient for Gemma 2 27B. static constexpr size_t kMaxK = 36 * 1024; - MMStorage(const Allocator& allocator) + explicit MMEntireA(const Allocator& allocator) // 288 MiB. Must be padded, see `DoDecompressA`. : A_("A_bf", Extents2D(kMaxBatchSize, kMaxK), allocator, MatPadding::kOdd) {} - // Returns matrix view. Converting A=F32 to BF16 up-front is faster than - // on-the-fly when native BF16 is available: it only happens once, not per B - // tile row, and the cache footprint is smaller. StridedViewBF A(const Extents2D& extents) const { HWY_DASSERT(extents.rows <= kMaxBatchSize); - HWY_DASSERT(extents.cols <= kMaxK); - return StridedViewBF(const_cast(A_.Row(0)), extents.cols, - A_.Stride()); + return StridedViewBF(A_, 0, 0, extents.cols); } private: MatStorageT A_; }; +// One tile of C per *worker* (required for `kNT_MT*`). 
+class MMTilesC { + public: + explicit MMTilesC(const ThreadingContext& ctx) { + const size_t max_workers = ctx.pools.MaxWorkers(); + C_.reserve(max_workers); + for (size_t worker = 0; worker < max_workers; ++worker) { + C_.push_back(MatStorageT("Ctile", Extents2D(kMaxBatchSize, kMaxNC), + ctx.allocator, MatPadding::kOdd)); + } + } + + StridedViewBF C(const Extents2D& extents, size_t worker) const { + HWY_DASSERT(extents.rows <= kMaxBatchSize); + HWY_DASSERT(worker < C_.size()); + return StridedViewBF(C_[worker], 0, 0, extents.cols); + } + + private: + std::vector> C_; +}; + //------------------------------------------------------------------------------ // Autotuning @@ -471,7 +491,7 @@ static_assert(sizeof(MMConfig) == 32); // for faster indexing #pragma pack(pop) std::vector MMCandidates(const CacheInfo& cache, size_t M, size_t K, - size_t N, size_t sizeof_TC, + size_t N, size_t num_B, size_t sizeof_TC, bool print_config); // State machine for choosing the best `TConfig`, which is `MMConfig` for the @@ -595,12 +615,14 @@ class MMKeys { static constexpr Key kPadding = 0; // Compresses the dimensions into a single Key for faster comparison. - static Key KeyFromDims(size_t M, size_t K, size_t N) { + static Key KeyFromDims(size_t M, size_t K, size_t N, size_t num_B) { HWY_DASSERT(M < (Key{1} << 16)); // batch sizes are smaller - HWY_DASSERT(K < (Key{1} << 24)); - HWY_DASSERT(N < (Key{1} << 24)); + HWY_DASSERT(K < (Key{1} << 20)); + HWY_DASSERT(N < (Key{1} << 20)); + HWY_DASSERT(num_B == 1 || num_B == 2); const Key key = static_cast(BucketM(M)) | (static_cast(K) << 16) | - (static_cast(N) << 40); + (static_cast(N) << 40) | + (static_cast(num_B) << 60); HWY_DASSERT(key != kPadding); return key; } @@ -643,10 +665,6 @@ class MMKeys { // Per-MatMul-shape state. struct MMPerKey { - // Only profile if enabled and the main autotuner finished. `autotune_par_a` - // might not be active if inputs are all BF16. - bool WantProfile() const { return PROFILER_ENABLED != 0 && autotune.Best(); } - MMAutoTune autotune; MMAutoTune autotune_par_a; }; @@ -666,12 +684,15 @@ struct MatMulEnv { // Whether to print the best config immediately after autotuning finished. bool print_best = false; - MMStorage storage; + MMEntireA A_BF; + MMTilesC C_tiles; struct PerCluster { MMKeys keys; std::vector per_key; - HWY_MAYBE_UNUSED uint8_t padding[HWY_ALIGNMENT]; // prevent false sharing + // Prevents false sharing. + HWY_MAYBE_UNUSED uint8_t + padding[HWY_ALIGNMENT - sizeof(MMKeys) - sizeof(per_key)]; }; std::vector per_cluster; @@ -687,31 +708,57 @@ struct MatMulEnv { std::vector> row_ptrs; }; -// Called with the entire C matrix, the sub-ranges of M (rows) and N (cols) -// that this thread has just filled, a view into a second tile (only for the -// upcoming `GatedMatmul`), and the worker thread index (see `ParallelFor`). -using MMFused = std::function; +// Called via `CallClosure`, which consumes the first (opaque) argument. User +// functions are called with the entire C matrix, the sub-ranges of M (rows) +// and N (cols) that this thread has just filled, a view into a second tile +// (only for `TwoMatmul`), and the worker thread index (see `ParallelFor`). +typedef void (*MMFunc)(const void* opaque, RowPtrsBF, IndexRange, IndexRange, + StridedViewBF, size_t); + +class MMOptions { + // Same technique as in `hwy::ThreadPool` and C++23 `std::function_ref`: + // type-erasure without allocation. 
+ template + static void CallClosure(const void* opaque, RowPtrsBF C1, IndexRange range_r, + IndexRange range_c, StridedViewBF C2, size_t worker) { + (*reinterpret_cast(opaque))(C1, range_r, range_c, C2, + worker); + } + + public: + // `closure` must remain alive until the end of (Two)MatMul. + template + void SetFunc(const Closure& closure) { + func = static_cast(&CallClosure); + opaque = &closure; + } + + void MaybeCallFunc(RowPtrsBF C1, IndexRange range_r, IndexRange range_c, + StridedViewBF C2, size_t worker) const { + if (func != nullptr) { + func(opaque, C1, range_r, range_c, C2, worker); + } + } + + MMFunc func = nullptr; // called if non-null and `TC` is BF16. + const void* opaque = nullptr; -struct MMOptions { uint32_t cluster_idx = 0; // for `parallelism == kWithinCluster`. ParallelismStrategy parallelism = ParallelismStrategy::kHierarchical; - - MMFused fused; // called if non-null and `TC` is BF16. }; // Arguments to MatMul() that are independent of the A/B/C types. Reduces // register pressure compared to individual values/references. Also used for // passing through `DispatchOrder`. struct MMArgs { - MMArgs(MatMulEnv& env, size_t M, size_t K, size_t N, double scale, + MMArgs(MatMulEnv& env, size_t M, size_t K, size_t N, float scale_A, const float* HWY_RESTRICT add, MMOptions options, const MMAutoTune& autotune, const MMConfig& config) : env(env), line_bytes(env.ctx.cache_info.LineBytes()), range_n(0, N), - scale(scale), + scale_A(scale_A), add(add), options(options), @@ -728,7 +775,8 @@ struct MMArgs { // MatMul arguments: const IndexRange range_n; // entire N - const double scale; + // There can be two B, so do not yet multiply together the A and B scales. + const float scale_A; const float* HWY_RESTRICT add; const MMOptions options; diff --git a/ops/matmul_static-inl.h b/ops/matmul_static-inl.h index ba09e0c..abb6f43 100644 --- a/ops/matmul_static-inl.h +++ b/ops/matmul_static-inl.h @@ -53,6 +53,14 @@ namespace HWY_NAMESPACE { // included from matmul_static_*.cc. GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DEFINE_ONE, GEMMA_MATMUL_TB) // NOLINT +HWY_MAYBE_UNUSED void TwoMatMulStatic(const MatPtrT& A, // NOLINT + const MatPtrT& B1, + const MatPtrT& B2, + MatMulEnv& env, MatPtrT& C, + MMOptions options) { + TwoMatMul(A, B1, B2, env, C, options); +} + } // namespace HWY_NAMESPACE } // namespace gcpp HWY_AFTER_NAMESPACE(); diff --git a/ops/matmul_static.h b/ops/matmul_static.h index 61dc505..6b93d92 100644 --- a/ops/matmul_static.h +++ b/ops/matmul_static.h @@ -37,13 +37,19 @@ const float* HWY_RESTRICT add, MatMulEnv& env, \ MatPtrT& C, MMOptions options); +#define GEMMA_MATMUL_FOR_B(TB) \ + GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, TB) \ + void TwoMatMulStatic(const MatPtrT& A, const MatPtrT& B1, \ + const MatPtrT& B2, MatMulEnv& env, \ + MatPtrT& C, MMOptions options); + // Passed to HWY_VISIT_TARGETS; declares all overloads for all targets. 
#define GEMMA_MATMUL_DECL(TARGET, NAMESPACE) \ namespace NAMESPACE { \ - GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, BF16) \ - GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, float) \ - GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, NuqStream) \ - GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, SfpStream) \ + GEMMA_MATMUL_FOR_B(BF16) \ + GEMMA_MATMUL_FOR_B(float) \ + GEMMA_MATMUL_FOR_B(NuqStream) \ + GEMMA_MATMUL_FOR_B(SfpStream) \ /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ } // namespace NAMESPACE diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc index 373f8aa..2f0fde2 100644 --- a/ops/matmul_test.cc +++ b/ops/matmul_test.cc @@ -29,6 +29,8 @@ #include #include +#include + #include "ops/matmul.h" #include "util/basics.h" #include "util/mat.h" @@ -246,7 +248,9 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add, MatStorageT C_slow("C_slow", C_extents, env.ctx.allocator, MatPadding::kOdd); MatStorageT C("C", C_extents, env.ctx.allocator, MatPadding::kOdd); + MatStorageT C2("C", C_extents, env.ctx.allocator, MatPadding::kOdd); C.AllocateAndAttachRowPtrs(env.row_ptrs); + C2.AllocateAndAttachRowPtrs(env.row_ptrs); MatStorageT add_storage = add ? GenerateMat(Extents2D(1, cols_bc), env.ctx.allocator, @@ -262,7 +266,48 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add, for (size_t rep = 0; rep < 16; ++rep) { MMPerKey* per_key = MatMulStatic(A, BT, add_row, env, C, options); AssertClose(A, BT, C_slow, C, env, line); - if (per_key->autotune.Best()) break; + // Check before TwoMatMulStatic(), which can invalidate per_key. + const bool autotune_done = !!per_key->autotune.Best(); + + // Ensure the tiled view returns the same result as C. + if constexpr (IsBF16() && IsBF16()) { + // The total view area should match the entire C matrix. + std::atomic total_view_area = 0; + + const auto fused = [&](RowPtrsBF C2_rows, IndexRange range_r, + IndexRange range_c, StridedViewBF C2_view, + size_t worker) { + total_view_area.fetch_add(range_r.Num() * range_c.Num()); + HWY_ASSERT(range_c.Num() <= C2_view.Cols()); + HWY_ASSERT(worker < env.ctx.pools.MaxWorkers()); + for (size_t ir = 0; ir < range_r.Num(); ++ir) { + const size_t r = range_r.begin() + ir; + for (size_t ic = 0; ic < range_c.Num(); ++ic) { + const size_t c = range_c.begin() + ic; + const float expected = + hwy::ConvertScalarTo(C2_rows.Row(r)[c]); + const float actual = + hwy::ConvertScalarTo(C2_view.Row(ir)[ic]); + const float L1 = hwy::ScalarAbs(actual - expected); + if (L1 > 1E-6f) { + HWY_ABORT("%zu: ir %zu ic %zu L1 %f expected %f actual %f.", + worker, ir, ic, L1, expected, actual); + } + } + } + }; + options.SetFunc(fused); + TwoMatMulStatic(A, BT, BT, env, C2, options); + HWY_ASSERT_EQ(C.Extents().Area(), total_view_area.load()); + options.func = nullptr; // reset for next call + + // TwoMatMulStatic() does not support adding a bias vector. 
+ if (!add) { + AssertClose(A, BT, C, C2, env, line); + } + } + + if (autotune_done) break; } } diff --git a/ops/ops-inl.h b/ops/ops-inl.h index cfd85ae..1593aa4 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -69,6 +69,14 @@ MMPerKey* CallMatMul(const MatPtrT& A, const MatPtr& B, }); } +static inline void CallTwoMatMul(const MatPtrT& A, const MatPtr& B1, + const MatPtr& B2, MatMulEnv& env, + MatPtrT& C, const MMOptions& options) { + return CallUpcastedSame(&B1, &B2, [&](const auto* B1_t, const auto* B2_t) { + return TwoMatMulStatic(A, *B1_t, *B2_t, env, C, options); + }); +} + HWY_INLINE double PackTokenAndProb(int32_t token, float prob) { // casting prob from float to double just makes some changes to the // exponent bias and pads zeros in the mantissa. diff --git a/util/mat.h b/util/mat.h index c2427e5..6f9a243 100644 --- a/util/mat.h +++ b/util/mat.h @@ -40,10 +40,11 @@ class RowPtrs { public: RowPtrs(uint8_t** row_ptrs) : row_ptrs_(row_ptrs), r0_(0), c0_(0) {} - RowPtrs View(size_t r, size_t c) { + // Extra argument is for compatibility with `StridedView`. + RowPtrs View(size_t r, size_t c, size_t /*cols*/) { RowPtrs view(row_ptrs_); - view.r0_ = static_cast(r); - view.c0_ = static_cast(c); + view.r0_ = static_cast(r0_ + r); + view.c0_ = static_cast(c0_ + c); return view; } @@ -531,7 +532,11 @@ class StridedView { : row0_(row0), cols_(static_cast(cols)), stride_(static_cast(stride)) { - HWY_DASSERT(stride >= cols); + if constexpr (HWY_IS_DEBUG_BUILD) { + if (stride < cols) { + HWY_ABORT("stride %zu < cols %zu", stride, cols); + } + } } // Returns 2D subrange whose top-left is `r, c` and width is `cols`.