mirror of https://github.com/google/gemma.cpp.git

commit f3bc1c17da (parent 59db30e209)

1.03x speedup: fused FFN

matmul-inl: support CView=StridedView or RowPtrs; rename to C_MC_NC.
matmul.cc: Allow 1 more rep for MC/NC to allow half-sized tiles, which helps.

PiperOrigin-RevId: 807291701
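In outline, the change replaces the unfused FFN (two full MatMuls followed by a
separate activation pass over C1/C2) with a single TwoMatMul that computes both
products tile by tile and applies the gating while the tiles are still
cache-resident. A rough scalar sketch of the idea; W1/W2 stand for
gating_einsum_w1/w2, and the scalar loop stands in for the vectorized Highway
code in the diff below:

    // Unfused (GEMMA_FUSED_FFN == 0): C1 = A*W1, C2 = A*W2, then a third pass
    // over memory applies c1[i] = Gelu(c1[i]) * c2[i].
    // Fused: TwoMatMul(A, W1, W2, env, C1, options) invokes a callback per
    // finished tile, which performs the same gating in place:
    void GateTileScalar(float* c1, const float* c2, size_t n) {
      for (size_t i = 0; i < n; ++i) c1[i] = Gelu(c1[i]) * c2[i];
    }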
@@ -37,7 +37,6 @@ class GemmaBatchBench : public ::testing::Test {
 protected:
  std::vector<std::string> BatchGemmaReply(
      const std::vector<std::string>& inputs) {
    s_env->SetMaxGeneratedTokens(24);
    s_env->MutableConfig().temperature = 0.0f;  // deterministic
    s_env->MutableConfig().verbosity = 2;
    std::vector<std::string> replies;
@@ -92,15 +91,18 @@ TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
    inputs.push_back(questions[qpos++]);
    if (qpos == questions.size()) qpos = 0;
  }
  s_env->SetMaxGeneratedTokens(24);
  std::vector<std::string> responses = BatchGemmaReply(inputs);
  for (size_t i = 0; i < HWY_MIN(hwy::Unpredictable1() * 3, responses.size());
       ++i) {
    fprintf(stderr, "Batch answer %zu '%s'\n\n", i, responses[i].c_str());
  }

  PROFILER_PRINT_RESULTS();

  // Run again: prefill will be faster due to autotuning. Fewer decode steps
  // because those are already fast.
  s_env->SetMaxGeneratedTokens(3);
  s_env->SetMaxGeneratedTokens(2);
  responses = BatchGemmaReply(inputs);

  PROFILER_PRINT_RESULTS();
@@ -36,6 +36,10 @@ HWY_INLINE_VAR constexpr int kAttentionUseOld = 2;
HWY_INLINE_VAR constexpr size_t kMaxQKVDim = 1024;

#ifndef GEMMA_FUSED_FFN
#define GEMMA_FUSED_FFN 1
#endif  // !GEMMA_FUSED_FFN

// Instruction-tuned models require extra 'turn structure' tokens in prompts.
enum class PromptWrapping {
  GEMMA_IT,
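Because the macro is only defined when not already set, the fused path can
presumably be disabled from the build command rather than by editing the
header (hypothetical invocation):

    // Fall back to the unfused FFN:
    //   -DGEMMA_FUSED_FFN=0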
@@ -43,6 +43,7 @@ HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {

// For use by Vit even if !GEMMA_FUSED_FFN.
template <typename T1, typename T2>
void Activation(ActivationType activation, T1* HWY_RESTRICT c1,
                const T2* HWY_RESTRICT c2, const size_t count, hwy::Profiler& p,
@@ -64,7 +65,7 @@ void Activation(ActivationType activation, T1* HWY_RESTRICT c1,
  });
}

// No C2 multiplier.
// No C2 multiplier - used by Vit.
template <class Mat>
void ActivationBatched(
    ActivationType activation, Mat& c1, ThreadingContext& ctx,
@@ -80,6 +81,34 @@ void ActivationBatched(
  });
}

#if GEMMA_FUSED_FFN

// Called during `TwoMatMul`.
static inline void Activation(ActivationType activation, const RowPtrsBF C1,
                              const IndexRange range_r,
                              const IndexRange range_c, const StridedViewBF C2,
                              hwy::Profiler& p, const size_t worker) {
  static const auto zone = p.AddZone("Gen.ActivationFused");
  PROFILER_ZONE3(p, worker, zone);

  const size_t cols = range_c.Num();
  HWY_DASSERT(C2.Cols() == cols);

  namespace hn = hwy::HWY_NAMESPACE;
  using DF = hn::ScalableTag<float>;
  using VF = hn::Vec<DF>;
  // ActivationType::Gelu
  // Gated: Gelu(c1) * c2.
  for (size_t ir = 0; ir < range_r.Num(); ++ir) {
    Decompress1AndCompressInplace(
        DF(), C1.Row(range_r.begin() + ir) + range_c.begin(), cols, C2.Row(ir),
        [](DF df, VF v1, VF v2)
            HWY_ATTR -> VF { return hn::Mul(v2, Gelu(df, v1)); });
  }
}

#else

template <class Mat1, class Mat2>
HWY_NOINLINE void ActivationBatched(
    ActivationType activation, Mat1& c1, const Mat2* c2, ThreadingContext& ctx,
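For reference, the fused Activation above amounts to the following scalar loop
over the just-computed tile (a sketch: the scalar Gelu helper and the explicit
BF16 conversions stand in for Decompress1AndCompressInplace):

    // C1 rows are offset into the full C matrix; C2 is the per-worker tile.
    for (size_t ir = 0; ir < range_r.Num(); ++ir) {
      BF16* c1 = C1.Row(range_r.begin() + ir) + range_c.begin();
      const BF16* c2 = C2.Row(ir);
      for (size_t i = 0; i < cols; ++i) {
        const float gated = Gelu(hwy::ConvertScalarTo<float>(c1[i])) *
                            hwy::ConvertScalarTo<float>(c2[i]);
        c1[i] = hwy::ConvertScalarTo<BF16>(gated);  // stored back in place
      }
    }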
@@ -102,6 +131,8 @@ HWY_NOINLINE void ActivationBatched(
  }
}

#endif  // GEMMA_FUSED_FFN

template <typename T2, class LayerWeights>
HWY_NOINLINE void ResidualConnection(const MatPtrT<T2>& other,
                                     MatPtrT<float>& HWY_RESTRICT x,
@@ -126,28 +157,32 @@ static inline void FFWNoVit(const LayerWeightsPtrs& layer,
      env.ctx.profiler.AddZone("Gen.FFW", hwy::ProfilerFlags::kInclusive);
  PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone);
  const LayerConfig& layer_config = layer.layer_config;
  const size_t ffh_hidden_dim = layer_config.ff_hidden_dim;

  const bool add_bias = layer_config.ff_biases;
  const float* bias1 =
      add_bias ? layer.ffw_gating_biases.PackedScale1() : nullptr;
  const float* bias2 = add_bias ? bias1 + ffh_hidden_dim : nullptr;
  const float* output_bias =
      add_bias ? layer.ffw_output_biases.PackedScale1() : nullptr;
  HWY_DASSERT(!layer_config.ff_biases);  // Only used in Vit.

#if GEMMA_FUSED_FFN
  const auto fused = [&](RowPtrsBF C1, IndexRange range_r, IndexRange range_c,
                         StridedViewBF C2, size_t worker) {
    Activation(layer_config.activation, C1, range_r, range_c, C2,
               env.ctx.profiler, worker);
  };
  MMOptions options;
  options.SetFunc(fused);
  CallTwoMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w1,
                layer.gating_einsum_w2, env, activations.C1, options);
#else
  // Compute the hidden layer activations.
  CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w1, bias1, env,
  CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w1, nullptr, env,
             activations.C1);
  CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w2, bias2, env,
  CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w2, nullptr, env,
             activations.C2);

  // Activation (Gelu) and maybe multiply by gate. Store activations in act.
  ActivationBatched(layer_config.activation, activations.C1, &activations.C2,
                    env.ctx);
#endif

  // Hidden layer -> output layer.
  CallMatMul(activations.C1, layer.linear_w, output_bias, env,
             activations.ffw_out);
  CallMatMul(activations.C1, layer.linear_w, nullptr, env, activations.ffw_out);
}

// NOLINTNEXTLINE(google-readability-namespace-comments)

ops/matmul-inl.h (404 changed lines)
@@ -155,14 +155,14 @@ class MMStoreHorizontalSumsIntoC {
  template <class D4, class V4 = hn::Vec<D4>, class Tag, class CView>
  HWY_INLINE void Store(D4 d4, V4 sum0, V4 sum1, V4 sum2, V4 sum3,
                        const float scale, const float* HWY_RESTRICT add,
                        const size_t imc, Tag tag, CView C_rows) const {
                        const size_t imc, Tag tag, CView C_MC_NR) const {
    const V4 vscale = hn::Set(d4, scale);
    HWY_ALIGN static constexpr float kZero[4] = {};
    const V4 vadd = hn::Load(d4, add ? add : kZero);
    MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, tag, imc, C_rows);
    MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, tag, imc, C_rows);
    MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, tag, imc, C_rows);
    MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, tag, imc, C_rows);
    MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, tag, imc, C_MC_NR);
    MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, tag, imc, C_MC_NR);
    MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, tag, imc, C_MC_NR);
    MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, tag, imc, C_MC_NR);
  }

 private:
@@ -202,10 +202,10 @@ class MMStoreHorizontalSumsIntoC {
            class Tag, class CView>
  static HWY_INLINE void MaybeScaleAndStore(DF4 df4, VF4 sum, VF4 vscale,
                                            VF4 vadd, Tag, const size_t imc,
                                            CView C_view) {
                                            CView C_MC_NR) {
    if constexpr (kRow < kRowsAC) {
      using TC = hwy::RemoveCvRef<decltype(C_view.Row(0)[0])>;
      TC* HWY_RESTRICT pos = C_view.Row(imc + kRow);
      using TC = hwy::RemoveCvRef<decltype(C_MC_NR.Row(0)[0])>;
      TC* HWY_RESTRICT pos = C_MC_NR.Row(imc + kRow);
      const hn::Rebind<TC, DF4> dc4;
      if constexpr (hwy::IsSame<Tag, MMAddC>()) {
        vadd = F32FromTC(dc4, hn::Load(dc4, pos));  // load prior value
@@ -268,9 +268,9 @@ class MMDecompress {
    } else {
      // Always decompress. To reduce code size/compile time, we no longer
      // support a separate F32 kernel; most A are already BF16. We also only
      // have a single MMStorage.
      // have a single MMEntireA.
      HWY_ASSERT(options.cluster_idx == 0);
      const StridedViewBF A_view = env.storage.A(A.Extents());
      const StridedViewBF A_view = env.A_BF.A(A.Extents());
      AutotuneDecompressA(A, A_view, autotune, env, options);
      return A_view;
    }
@@ -387,111 +387,52 @@ class MMDecompress {

// Stateless, wraps member functions. Contains the innermost 2-4 loops.
class MMKernel {
  // Compute size of per-worker storage for `kNR` row ranges of B. Stack
  // allocation avoids passing a worker index.
  static constexpr size_t B_stride_max =
      kMaxKC + 2 * CacheInfo::MaxLineBytes() / sizeof(BF16);

 public:
  // A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output.
  // Calls `LoopKC` for each of `mc` rows of A in steps of `mr`. `A_view` is
  // `mc x kc` and `B_view` is `(kNR x kc)`. All views, including `add`, start
  // at row/col 0. `CView` is either `RowPtrs<TC>` or `StridedView<TC>`.
  // Called by B3A2C0 and by callers that hoist `A_view`.
  template <class Tag, class CView>
  static HWY_INLINE void A2C0(const StridedViewBF A_view,
                              const StridedViewBF B_view, size_t mr,
                              const IndexRange& range_mc, size_t kc,
                              const float scale, const float* HWY_RESTRICT add,
                              Tag tag, CView C_view) {
    HWY_DASSERT(1 <= mr && mr <= kMaxMR);

    const size_t mc = range_mc.Num();
    size_t imc = 0;

    // M == 1, or x86 with 8 SIMD registers:
    if (HWY_UNLIKELY(mr == 1)) {
      for (; imc < mc; ++imc) {
        LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_view);
      }
      return;
    }

    // AVX2 (16 registers)
    if (HWY_UNLIKELY(mr == 2)) {
      if (HWY_LIKELY(mc >= 2)) {
        for (; imc <= mc - 2; imc += 2) {
          LoopKC<2>(A_view, B_view, imc, kc, scale, add, tag, C_view);
        }
      }
      if (HWY_UNLIKELY(imc != mc)) {
        LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_view);
      }
      return;
    }

    HWY_DASSERT(mr == 4);
    if (HWY_LIKELY(mc >= 4)) {
      for (; imc <= mc - 4; imc += 4) {
        LoopKC<4>(A_view, B_view, imc, kc, scale, add, tag, C_view);
      }
    }
    const size_t remainder_mc = mc - imc;
    HWY_DASSERT(remainder_mc < 4);
    if (HWY_UNLIKELY(remainder_mc & 2)) {
      LoopKC<2>(A_view, B_view, imc, kc, scale, add, tag, C_view);
      imc += 2;
    }
    if (HWY_UNLIKELY(remainder_mc & 1)) {
      LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_view);
      imc += 1;
    }
    HWY_DASSERT(imc == mc);
  }

  static constexpr size_t B_storage_max = kNR * B_stride_max;

  // Loop over NC/MC/KC, called from the outer loops. The MOMMS B3A2C0 reads
  // `mc x kc` of A, `nc x kc` of B, and updates `mc x nc` of C. Called by
  // `ForeachKC` and when there is only a single KC task.
  template <typename TB, typename TC, typename Tag>
  // `mc x kc` of A, `nc x kc` of B, and updates the `mc x nc` `C_MC_NC`.
  // `CView` is either `RowPtrs<TC>` or `StridedView<TC>`.
  template <typename TB, typename Tag, class CView>
  static void B3A2C0(const StridedViewBF A, const MatPtrT<TB>& B,
                     const IndexRange& range_mc, const IndexRange& range_kc,
                     const IndexRange& range_nc, const MMArgs& args,
                     Tag out_tag, RowPtrs<TC> C) {
    HWY_ALIGN BF16 B_storage[B_storage_max];

                     Tag out_tag, CView C_MC_NC) {
    const size_t kc = range_kc.Num();
    const StridedViewBF A_view = A.View(range_mc.begin(), range_kc.begin(), kc);

    // Upper bound on per-worker storage for `kNR` row ranges of B. Stack
    // allocation avoids passing a worker index.
    constexpr size_t B_stride_max =
        kMaxKC + 2 * CacheInfo::MaxLineBytes() / sizeof(BF16);
    HWY_ALIGN BF16 B_storage[kNR * B_stride_max];
    const size_t B_stride =
        Stride(MatPadding::kOdd, kc, sizeof(BF16), args.line_bytes);
    const StridedViewBF B_storage_view(B_storage, kc, B_stride);

    for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
         row_b += kNR) {
    const float scale = args.scale_A * B.Scale();
    for (size_t inc = 0; inc < range_nc.Num(); inc += kNR) {
      // For `add` and `B`, which are global, unlike `C_MC_NC`.
      const size_t row_b = range_nc.begin() + inc;
      const StridedViewBF B_view =
          MMDecompress::DecompressB(B, row_b, range_kc, B_storage_view);
      const RowPtrs<TC> C_view = C.View(range_mc.begin(), row_b);
      const CView C_MC_NR = C_MC_NC.View(0, inc, kNR);
      const float* HWY_RESTRICT add = args.add ? args.add + row_b : nullptr;
      A2C0(A_view, B_view, args.mr, range_mc, kc, args.scale, add, out_tag,
           C_view);
      A2C0(A_view, B_view, args.mr, range_mc, kc, scale, add, out_tag, C_MC_NR);
    }
  }

  template <typename TB, typename TC>
  template <typename TB, class CView>
  static void ForeachKC(const StridedViewBF A, const MatPtrT<TB>& B,
                        const IndexRange& range_mc,
                        const IndexRangePartition& ranges_kc,
                        const IndexRange& range_nc, const MMArgs& args,
                        RowPtrs<TC> C) {
                        CView C_MC_NC) {
    // Peel off the first iteration of the kc loop: avoid zero-initializing `C`
    // by writing directly into it, and later accumulating into it.
    ranges_kc.VisitFirst([&](const IndexRange& range_kc) {
      B3A2C0(A, B, range_mc, range_kc, range_nc, args, MMSetC(), C);
      B3A2C0(A, B, range_mc, range_kc, range_nc, args, MMSetC(), C_MC_NC);
    });
    ranges_kc.VisitRemaining([&](const IndexRange& range_kc) {
      B3A2C0(A, B, range_mc, range_kc, range_nc, args, MMAddC(), C);
      B3A2C0(A, B, range_mc, range_kc, range_nc, args, MMAddC(), C_MC_NC);
    });
  }
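Condensed, the loop order above is a classic blocked GEMM nest. A self-contained
plain-float illustration (no decompression, SIMD, scaling, or remainder
handling; kNR fixed at 4 as declared in matmul.h):

    #include <cstddef>

    // C (M x N, row-major) += A (M x K, row-major) * B^T, where B is stored as
    // N x K so each B row is a dot-product operand, as after DecompressB.
    void B3A2C0_Sketch(const float* A, const float* B, float* C,
                       size_t mc, size_t nc, size_t kc, size_t K, size_t N) {
      for (size_t jn = 0; jn < nc; jn += 4) {     // one kNR-wide B panel at a time
        for (size_t im = 0; im < mc; ++im) {      // A2C0: A rows reuse the panel
          for (size_t j = jn; j < jn + 4; ++j) {  // the mr x kNR output tile
            float sum = 0.0f;
            for (size_t k = 0; k < kc; ++k) sum += A[im * K + k] * B[j * K + k];
            C[im * N + j] += sum;                 // MMAddC; MMSetC would assign
          }
        }
      }
    }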
@@ -585,15 +526,15 @@ class MMKernel {
  // Innermost loop over `kc` columns (typically 1024-4096, not necessarily a
  // multiple of `NBF`) in steps of one vector, for `kRowsAC` rows of `A_view`
  // from range_mc-relative `imc` and `B_view` from row 0 (both at column 0).
  // Updates a `kRowsAC x kNR` tile in `C_view` starting at row `imc`, column 0.
  // `A` and `B` are always BF16, `C` can be F32 or BF16. `add` is also
  // Updates a `kRowsAC x kNR` tile in `C_MC_NR` starting at row `imc`, column
  // 0. `A` and `B` are always BF16, `C` can be F32 or BF16. `add` is also
  // relative to the C column.
  template <size_t kRowsAC, /*deduced:*/ class Tag, class CView>
  static HWY_INLINE void LoopKC(const StridedViewBF A_view,
                                const StridedViewBF B_view, size_t imc,
                                size_t kc, const float scale,
                                const float* HWY_RESTRICT add, Tag tag,
                                CView C_view) {
                                CView C_MC_NR) {
    const hn::ScalableTag<BF16> dbf;
    using VBF = hn::Vec<decltype(dbf)>;
    HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf);
@@ -777,7 +718,62 @@ class MMKernel {
    hn::Vec<decltype(d4)> sum0, sum1, sum2, sum3;
    horz.Reduce4x4(df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22,
                   C23, C30, C31, C32, C33, sum0, sum1, sum2, sum3);
    horz.Store(d4, sum0, sum1, sum2, sum3, scale, add, imc, tag, C_view);
    horz.Store(d4, sum0, sum1, sum2, sum3, scale, add, imc, tag, C_MC_NR);
  }

  // A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output.
  // Calls `LoopKC` for each of `mc` rows of A in steps of `mr`. `A_view` is
  // `mc x kc` and `B_view` is `(kNR x kc)`. All views, including `add`, start
  // at row/col 0.
  template <class Tag, class CView>
  static HWY_INLINE void A2C0(const StridedViewBF A_view,
                              const StridedViewBF B_view, size_t mr,
                              const IndexRange& range_mc, size_t kc,
                              const float scale, const float* HWY_RESTRICT add,
                              Tag tag, CView C_MC_NR) {
    HWY_DASSERT(1 <= mr && mr <= kMaxMR);

    const size_t mc = range_mc.Num();
    size_t imc = 0;

    // M == 1, or x86 with 8 SIMD registers:
    if (HWY_UNLIKELY(mr == 1)) {
      for (; imc < mc; ++imc) {
        LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_MC_NR);
      }
      return;
    }

    // AVX2 (16 registers)
    if (HWY_UNLIKELY(mr == 2)) {
      if (HWY_LIKELY(mc >= 2)) {
        for (; imc <= mc - 2; imc += 2) {
          LoopKC<2>(A_view, B_view, imc, kc, scale, add, tag, C_MC_NR);
        }
      }
      if (HWY_UNLIKELY(imc != mc)) {
        LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_MC_NR);
      }
      return;
    }

    HWY_DASSERT(mr == 4);
    if (HWY_LIKELY(mc >= 4)) {
      for (; imc <= mc - 4; imc += 4) {
        LoopKC<4>(A_view, B_view, imc, kc, scale, add, tag, C_MC_NR);
      }
    }
    const size_t remainder_mc = mc - imc;
    HWY_DASSERT(remainder_mc < 4);
    if (HWY_UNLIKELY(remainder_mc & 2)) {
      LoopKC<2>(A_view, B_view, imc, kc, scale, add, tag, C_MC_NR);
      imc += 2;
    }
    if (HWY_UNLIKELY(remainder_mc & 1)) {
      LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_MC_NR);
      imc += 1;
    }
    HWY_DASSERT(imc == mc);
  }
};
@@ -813,10 +809,10 @@ class MMImpl {
  }

 public:
  static MMPerKey& FindOrAddPerKey(size_t M, size_t K, size_t N,
  static MMPerKey& FindOrAddPerKey(size_t M, size_t K, size_t N, size_t num_B,
                                   size_t vector_bytes,
                                   MatMulEnv::PerCluster& per_cluster) {
    const MMKeys::Key key = MMKeys::KeyFromDims(M, K, N);
    const MMKeys::Key key = MMKeys::KeyFromDims(M, K, N, num_B);
    intptr_t index = IndexOfKey(key, per_cluster.keys);
    // First time we see this shape/key.
    if (HWY_UNLIKELY(index < 0)) {
@@ -831,17 +827,19 @@ class MMImpl {
  }

  static void NotifyAutotuneResult(MatMulEnv& env, size_t M, size_t K, size_t N,
                                   double t0, MMAutoTune<MMConfig>& tuner,
                                   size_t num_B, double t0,
                                   MMAutoTune<MMConfig>& tuner,
                                   const MMConfig& cfg) {
    const uint64_t t1 =
        env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
    const double min_elapsed = static_cast<double>(tuner.NotifyTicks(t1 - t0)) /
                               hwy::platform::InvariantTicksPerSecond();
    const double flops = 2 * M * K * N / min_elapsed;  // * 2 for FMA
    const double flops = 2 * M * K * N * num_B / min_elapsed;  // * 2 for FMA
    if (HWY_UNLIKELY(env.print_measurement && tuner.ShouldPrint())) {
      fprintf(stderr, "%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu\n", flops * 1E-9,
              min_elapsed * 1E3, cfg.MR(), cfg.MC(), cfg.KC(), cfg.NC(),
              StringFromOrder(cfg.Order()), cfg.InnerTasks());
      fprintf(stderr, "%zu,%zu,%zu,%zu,%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu\n",
              M, K, N, num_B, flops * 1E-9, min_elapsed * 1E3, cfg.MR(),
              cfg.MC(), cfg.KC(), cfg.NC(), StringFromOrder(cfg.Order()),
              cfg.InnerTasks());
    }
    if (HWY_UNLIKELY(env.print_best && tuner.Best())) {
      const auto ratio = [&tuner](uint64_t ticks) -> double {
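A worked instance of the throughput formula above, with made-up numbers to show
the units (the print uses flops * 1E-9, i.e. GFLOP/s):

    const double M = 64, K = 4096, N = 16384, num_B = 2;       // hypothetical
    const double min_elapsed = 0.010;                          // seconds
    const double flops = 2 * M * K * N * num_B / min_elapsed;  // ~1.72e12 FLOP/s
    printf("%7.1f\n", flops * 1E-9);                           // prints ~1718.0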
@@ -849,12 +847,13 @@ class MMImpl {
               static_cast<double>(tuner.BestTicks());
      };
      const MMConfig& best = *tuner.Best();
      fprintf(stderr,
              "\n%zu,%zu,%zu,%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu,%.2f,%.2f\n",
              M, K, N, flops * 1E-9, min_elapsed * 1E3, best.MR(), best.MC(),
              best.KC(), best.NC(), StringFromOrder(best.Order()),
              best.InnerTasks(), ratio(tuner.WorstMinTicks()),
              ratio(tuner.FirstConfigTicks()));
      fprintf(
          stderr,
          "\n%zu,%zu,%zu,%zu,%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu,%.2f,%.2f\n",
          M, K, N, num_B, flops * 1E-9, min_elapsed * 1E3, best.MR(), best.MC(),
          best.KC(), best.NC(), StringFromOrder(best.Order()),
          best.InnerTasks(), ratio(tuner.WorstMinTicks()),
          ratio(tuner.FirstConfigTicks()));
    }
  }
@@ -874,10 +873,11 @@ class MMImpl {
class MMLoops {
 public:
  // Called from `MatMul` from two places: either with the next autotune config,
  // or with the best config.
  // or with the best config. `B2` is null unless called from `TwoMatMul`.
  template <typename TB, typename TC>
  static HWY_NOINLINE void Dispatch(const StridedViewBF A, const MatPtrT<TB>& B,
                                    RowPtrs<TC> C, const MMArgs& args) {
                                    const MatPtrT<TB>* B2, RowPtrs<TC> C,
                                    const MMArgs& args) {
    static const auto zone = args.env.ctx.profiler.AddZone("MM.Dispatch");
    PROFILER_ZONE3(args.env.ctx.profiler,
                   args.env.ctx.Worker(args.options.cluster_idx), zone);
@@ -885,7 +885,7 @@ class MMLoops {
    DispatchParallelism(
        args.options.parallelism, [&](const auto& parallel) HWY_ATTR {
          DispatchOrder(args.order, [&](const auto& order) HWY_ATTR {
            Loop(order, parallel, A, B, C, args);
            Loop(order, parallel, A, B, B2, C, args);
          });
        });
  }
@@ -901,18 +901,14 @@ class MMLoops {
  template <typename TB, typename TC, class Parallel>
  static HWY_INLINE void Loop(MMOrderNT, Parallel parallel,
                              const StridedViewBF A, const MatPtrT<TB>& B,
                              RowPtrs<TC> C, const MMArgs& args) {
                              const MatPtrT<TB>* B2, RowPtrs<TC> C,
                              const MMArgs& args) {
    static const auto zone = args.env.ctx.profiler.AddZone("MM.NT");
    HWY_DASSERT(args.ranges_mc.NumTasks() == 1);
    HWY_DASSERT(args.ranges_kc.NumTasks() == 1);
    const IndexRange& range_M = args.ranges_mc.Range(0);
    const IndexRange& range_K = args.ranges_kc.Range(0);
    const size_t K = range_K.Num();
    const StridedViewBF A_view = A.View(range_M.begin(), 0, K);
    const size_t B_stride =
        Stride(MatPadding::kOdd, K, sizeof(BF16), args.line_bytes);
    const IndexRange& range_mc = args.ranges_mc.Range(0);
    const IndexRange& range_kc = args.ranges_kc.Range(0);

    // Similar to `B3A2C0`, but here we hoisted `A_view`.
    parallel.ForN(
        args.env.ctx, args.range_n, MultipleN(sizeof(TC), args.line_bytes),
        args.inner_tasks, args.options.cluster_idx,
@@ -920,26 +916,19 @@ class MMLoops {
          MMZone mm_zone;
          mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune);

          HWY_ALIGN BF16 B_storage[MMKernel::B_storage_max];  // TLS
          const StridedViewBF B_storage_view(B_storage, K, B_stride);
          MMKernel::B3A2C0(A, B, range_mc, range_kc, range_nc, args, MMSetC(),
                           C.View(0, range_nc.begin(), range_nc.Num()));

          for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
               row_b += kNR) {
            const StridedViewBF B_view =
                MMDecompress::DecompressB(B, row_b, range_K, B_storage_view);
            const RowPtrs<TC> C_view = C.View(range_M.begin(), row_b);
            const float* HWY_RESTRICT add =
                args.add ? args.add + row_b : nullptr;
          const StridedViewBF C2 = args.env.C_tiles.C(
              Extents2D(range_mc.Num(), range_nc.Num()), worker);

            MMKernel::A2C0(A_view, B_view, args.mr, range_M, K, args.scale, add,
                           MMSetC(), C_view);
          if (B2 != nullptr) {
            MMKernel::B3A2C0(A, *B2, range_mc, range_kc, range_nc, args,
                             MMSetC(), C2);
          }

          if constexpr (IsBF16<TC>()) {
            if (args.options.fused) {
              StridedViewBF C2(nullptr, 0, 0);
              args.options.fused(C, range_M, range_nc, C2, worker);
            }
            args.options.MaybeCallFunc(C, range_mc, range_nc, C2, worker);
          }
        });
  }
@@ -948,7 +937,8 @@ class MMLoops {
  template <typename TB, typename TC, class Parallel>
  static HWY_INLINE void Loop(MMOrderNT_K, Parallel parallel,
                              const StridedViewBF A, const MatPtrT<TB>& B,
                              RowPtrs<TC> C, const MMArgs& args) {
                              const MatPtrT<TB>* B2, RowPtrs<TC> C,
                              const MMArgs& args) {
    static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_K");
    HWY_DASSERT(args.ranges_mc.NumTasks() == 1);
    const IndexRange& range_mc = args.ranges_mc.Range(0);
@@ -959,14 +949,21 @@ class MMLoops {
        [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
          MMZone mm_zone;
          mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune);
          MMKernel::ForeachKC(A, B, range_mc, args.ranges_kc,
                              range_nc, args, C);
          MMKernel::ForeachKC(
              A, B, range_mc, args.ranges_kc, range_nc, args,
              C.View(0, range_nc.begin(), range_nc.Num()));

          const StridedViewBF C2 = args.env.C_tiles.C(
              Extents2D(range_mc.Num(), range_nc.Num()), worker);

          if (B2 != nullptr) {
            MMKernel::ForeachKC(A, *B2, range_mc, args.ranges_kc,
                                range_nc, args, C2);
          }

          if constexpr (IsBF16<TC>()) {
            if (args.options.fused) {
              StridedViewBF C2(nullptr, 0, 0);
              args.options.fused(C, range_mc, range_nc, C2, worker);
            }
            args.options.MaybeCallFunc(C, range_mc, range_nc, C2,
                                       worker);
          }
        });
  }
@@ -976,10 +973,11 @@ class MMLoops {
  template <typename TB, typename TC, class Parallel>
  static HWY_INLINE void Loop(MMOrderNT_MT, Parallel parallel,
                              const StridedViewBF A, const MatPtrT<TB>& B,
                              RowPtrs<TC> C, const MMArgs& args) {
                              const MatPtrT<TB>* B2, RowPtrs<TC> C,
                              const MMArgs& args) {
    static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_MT");
    HWY_DASSERT(args.ranges_kc.NumTasks() == 1);
    const IndexRange& range_K = args.ranges_kc.Range(0);
    const IndexRange& range_kc = args.ranges_kc.Range(0);

    parallel.ForRangesMC_NC(
        args.env.ctx, args.ranges_mc, args.ranges_nc, args.options.cluster_idx,
@@ -987,14 +985,19 @@ class MMLoops {
            size_t worker) HWY_ATTR {
          MMZone mm_zone;
          mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune);
          MMKernel::B3A2C0(A, B, range_mc, range_K, range_nc, args, MMSetC(),
                           C);
          MMKernel::B3A2C0(
              A, B, range_mc, range_kc, range_nc, args, MMSetC(),
              C.View(range_mc.begin(), range_nc.begin(), range_nc.Num()));

          const StridedViewBF C2 = args.env.C_tiles.C(
              Extents2D(range_mc.Num(), range_nc.Num()), worker);

          if (B2 != nullptr) {
            MMKernel::B3A2C0(A, *B2, range_mc, range_kc, range_nc, args,
                             MMSetC(), C2);
          }
          if constexpr (IsBF16<TC>()) {
            if (args.options.fused) {
              StridedViewBF C2(nullptr, 0, 0);
              args.options.fused(C, range_mc, range_nc, C2, worker);
            }
            args.options.MaybeCallFunc(C, range_mc, range_nc, C2, worker);
          }
        });
  }
@@ -1004,7 +1007,8 @@ class MMLoops {
  template <typename TB, typename TC, class Parallel>
  static HWY_INLINE void Loop(MMOrderNT_MT_K, Parallel parallel,
                              const StridedViewBF A, const MatPtrT<TB>& B,
                              RowPtrs<TC> C, const MMArgs& args) {
                              const MatPtrT<TB>* B2, RowPtrs<TC> C,
                              const MMArgs& args) {
    static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_MT_K");

    parallel.ForRangesMC_NC(
@@ -1013,14 +1017,20 @@ class MMLoops {
            size_t worker) HWY_ATTR {
          MMZone mm_zone;
          mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune);
          MMKernel::ForeachKC(A, B, range_mc, args.ranges_kc, range_nc, args,
                              C);
          MMKernel::ForeachKC(
              A, B, range_mc, args.ranges_kc, range_nc, args,
              C.View(range_mc.begin(), range_nc.begin(), range_nc.Num()));

          const StridedViewBF C2 = args.env.C_tiles.C(
              Extents2D(range_mc.Num(), range_nc.Num()), worker);

          if (B2 != nullptr) {
            MMKernel::ForeachKC(A, *B2, range_mc, args.ranges_kc, range_nc,
                                args, C2);
          }

          if constexpr (IsBF16<TC>()) {
            if (args.options.fused) {
              StridedViewBF C2(nullptr, 0, 0);
              args.options.fused(C, range_mc, range_nc, C2, worker);
            }
            args.options.MaybeCallFunc(C, range_mc, range_nc, C2, worker);
          }
        });
  }
@@ -1060,20 +1070,23 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
  const size_t M = A.Rows();
  const size_t K = A.Cols();
  const size_t N = B.Rows();
  const size_t num_B = 1;

  const CacheInfo& cache = env.ctx.cache_info;
  MMPerKey& per_key = MMImpl::FindOrAddPerKey(M, K, N, cache.VectorBytes(),
                                              env.per_cluster[cluster_idx]);
  MMPerKey& per_key = MMImpl::FindOrAddPerKey(
      M, K, N, num_B, cache.VectorBytes(), env.per_cluster[cluster_idx]);

  // (Also auto-tunes, hence outside the timed section to prevent interference.)
  const StridedViewBF A_view =
      MMDecompress::MaybeDecompressA(A, per_key.autotune_par_a, env, options);

  MatPtrT<TB>* B2 = nullptr;  // required for type matching

  MMAutoTune<MMConfig>& tuner = per_key.autotune;
  if (HWY_LIKELY(tuner.Best())) {
    const MMArgs args(env, M, K, N, static_cast<double>(A.Scale()) * B.Scale(),
                      add, options, tuner, *tuner.Best());
    MMLoops::Dispatch(A_view, B, C_rows, args);
    const MMArgs args(env, M, K, N, A.Scale(), add, options, tuner,
                      *tuner.Best());
    MMLoops::Dispatch(A_view, B, B2, C_rows, args);
    return &per_key;
  }
@@ -1082,20 +1095,83 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
    // Ensure matrix dimensions match each other (off the hot path).
    HWY_ASSERT(K == B.Cols());
    HWY_ASSERT(M <= kMaxBatchSize);
    HWY_ASSERT(K <= MMStorage::kMaxK);
    HWY_ASSERT(K <= MMEntireA::kMaxK);
    HWY_ASSERT(N % kNR == 0);
    MMImpl::EnsureAligned(A, cache.VectorBytes());
    tuner.SetCandidates(
        MMCandidates(cache, M, K, N, sizeof(TC), env.print_config));
        MMCandidates(cache, M, K, N, num_B, sizeof(TC), env.print_config));
  }

  const MMConfig& cfg = tuner.NextConfig();
  const MMArgs args(env, M, K, N, static_cast<double>(A.Scale()) * B.Scale(),
                    add, options, tuner, cfg);
  const MMArgs args(env, M, K, N, A.Scale(), add, options, tuner, cfg);

  const uint64_t t0 = hwy::timer::Start();
  MMLoops::Dispatch(A_view, B, C_rows, args);
  MMImpl::NotifyAutotuneResult(env, M, K, N, t0, tuner, cfg);
  MMLoops::Dispatch(A_view, B, B2, C_rows, args);
  MMImpl::NotifyAutotuneResult(env, M, K, N, num_B, t0, tuner, cfg);

  return &per_key;
}

// Performs A*B1 and A*B2 in parallel. This is useful for gated FFNs.
// Differences vs MatMul: The second result matrix is not materialized, it is
// only passed to the `options.func` callback. There is no `add` argument
// because it is not required for this use case. There is no default `options`
// argument because `options.func` must be set by the caller.
template <typename TB>
HWY_NOINLINE MMPerKey* TwoMatMul(const MatPtrT<BF16>& A, const MatPtrT<TB>& B1,
                                 const MatPtrT<TB>& B2, MatMulEnv& env,
                                 MatPtrT<BF16>& C, MMOptions options) {
  static const auto zone = env.ctx.profiler.AddZone("MM.TwoMatMul");
  const size_t cluster_idx = options.cluster_idx;
  HWY_DASSERT(cluster_idx < env.row_ptrs.size());
  PROFILER_ZONE3(env.ctx.profiler, env.ctx.Worker(cluster_idx), zone);

  HWY_DASSERT(options.func != nullptr);  // no other way to get access to C2.

  RowPtrs<BF16> C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[cluster_idx]);

  const size_t M = A.Rows();
  const size_t K = A.Cols();
  const size_t N = B1.Rows();
  const size_t num_B = 2;

  const CacheInfo& cache = env.ctx.cache_info;
  MMPerKey& per_key = MMImpl::FindOrAddPerKey(
      M, K, N, num_B, cache.VectorBytes(), env.per_cluster[cluster_idx]);

  // (Also auto-tunes, hence outside the timed section to prevent interference.)
  const StridedViewBF A_view(A, 0, 0, A.Cols());

  MMAutoTune<MMConfig>& tuner = per_key.autotune;
  if (HWY_LIKELY(tuner.Best())) {
    // Only A scale - B1/B2 may differ, and are passed separately.
    const MMArgs args(env, M, K, N, A.Scale(),
                      /*add=*/nullptr, options, tuner, *tuner.Best());
    MMLoops::Dispatch(A_view, B1, &B2, C_rows, args);
    return &per_key;
  }

  // Autotuning, first call: enumerate all feasible configs.
  if (HWY_UNLIKELY(!tuner.HasCandidates())) {
    // Ensure matrix dimensions match each other (off the hot path).
    HWY_ASSERT(K == B1.Cols());
    HWY_ASSERT(K == B2.Cols());
    HWY_ASSERT(M <= kMaxBatchSize);
    HWY_ASSERT(K <= MMEntireA::kMaxK);
    HWY_ASSERT(N % kNR == 0);
    MMImpl::EnsureAligned(A, cache.VectorBytes());
    tuner.SetCandidates(
        MMCandidates(cache, M, K, N, num_B, sizeof(BF16), env.print_config));
  }

  const MMConfig& cfg = tuner.NextConfig();
  // Only A scale - B1/B2 may differ, and are passed separately.
  const MMArgs args(env, M, K, N, A.Scale(), /*add=*/nullptr, options, tuner,
                    cfg);

  const uint64_t t0 = hwy::timer::Start();
  MMLoops::Dispatch(A_view, B1, &B2, C_rows, args);
  MMImpl::NotifyAutotuneResult(env, M, K, N, num_B, t0, tuner, cfg);

  return &per_key;
}
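A minimal caller sketch for TwoMatMul, mirroring the FFW wiring and the test
elsewhere in this commit (identifiers as in the diff; the callback body is
whatever the caller wants to do with the two tiles):

    const auto fused = [&](RowPtrsBF C1, IndexRange range_r, IndexRange range_c,
                           StridedViewBF C2, size_t worker) {
      // C1 holds the just-computed A*B1 tile inside the full C matrix;
      // C2 is this worker's A*B2 tile for the same (range_r, range_c) region.
    };
    MMOptions options;
    options.SetFunc(fused);                 // required: no other access to C2
    TwoMatMul(A, B1, B2, env, C, options);  // note: no `add` argument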
@@ -63,11 +63,12 @@ size_t PrevDivisor(const size_t begin, const size_t end, const size_t dim,
class GenerateCandidates {
 public:
  GenerateCandidates(const CacheInfo& cache, size_t M, size_t K, size_t N,
                     size_t sizeof_TC, bool print_config)
                     size_t num_B, size_t sizeof_TC, bool print_config)
      : cache_(cache),
        M_(M),
        K_(K),
        N_(N),
        num_B_(num_B),
        sizeof_TC_(sizeof_TC),
        // These influence kc/nc, but are also stored in `MMConfig` for
        // `RangesOf*`. Must be a vector multiple. The previous/next cache line
@@ -150,7 +151,7 @@ class GenerateCandidates {
    }
  }

  // The number of A and B columns to read between updating `partial`.
  // The number of A and B columns to read between updating `C`.
  SizeVec KC(size_t mr, MMOrder order) const {
    // `LoopKC` handles up to `mr` rows of A.
    const size_t rows_a = HWY_MIN(M_, mr);
@@ -164,7 +165,7 @@ class GenerateCandidates {
    // TB=NUQ due to less amortization of the table loads. Due to the low L1
    // latency, the packing is still effectively fused into `LoopKC`. It may
    // be better to round up and accept a few L2 accesses in exchange for
    // fewer loops over K, and thus fewer writes to `partial`. Hence we do not
    // fewer loops over K, and thus fewer writes to `C`. Hence we do not
    // subtract the output and buf, and allow using more than the actual L1
    // size. This results in an overestimate, and the loop below will propose
    // the next few smaller values for the autotuner to evaluate.
@@ -179,7 +180,7 @@ class GenerateCandidates {

    // Avoid proposing kc > K.
    if (K_ > kc_multiple_) {
      // Generally it is best to use the full `kc` (fewer writes to `partial`),
      // Generally it is best to use the full `kc` (fewer writes to `C`),
      // but a bit less can be better if it evenly divides `K`, or enables an
      // `mc` that evenly divides `M`. Try several smaller values.
@@ -196,7 +197,7 @@ class GenerateCandidates {
    }

    if (print_config_ && all_kc.size() > 1) {
      fprintf(stderr, "KC: ");
      fprintf(stderr, "num_B %zu: KC: ", num_B_);
      for (size_t kc : all_kc) {
        fprintf(stderr, "%zu ", kc);
      }
@@ -214,18 +215,18 @@ class GenerateCandidates {

    // Choose the largest feasible `mc_max` (A/C rows) to maximize reuse of the
    // packed B. We want `mc * kc` elements of A to fit in L2, alongside
    // `bytes_b` plus `mc` cache lines because resident-A updates `mc` rows of
    // partial.
    // `bytes_b` plus `mc` cache lines because resident-A updates `mc` C rows.
    const size_t bytes_per_mc = kc * sizeof(BF16) + cache_.LineBytes();
    size_t mc_max = hwy::DivCeil(cache_.L2Bytes() - bytes_b, bytes_per_mc);
    mc_max = HWY_MIN(mc_max, kMaxBatchSize);
    mc_max = HWY_MIN(mc_max, HWY_MIN(kMaxBatchSize, kMaxMC));
    HWY_DASSERT(mc_max != 0);
    mc_max = HWY_MIN(mc_max, M_);
    mc_max = hwy::RoundDownTo(mc_max, mr);

    SizeVec all_mc(1, mc_max);
    // Larger MC is better for non-blocks, otherwise we want more small options.
    const size_t reps = !IsBlock(order) ? 2 : 3;
    // Larger MC is better for non-blocks, otherwise we want more small options,
    // especially for two B.
    const size_t reps = !IsBlock(order) ? 2 : (2 + num_B_);

    size_t prev = mc_max;
    for (size_t rep = 0; rep < reps; ++rep) {
@@ -240,7 +241,7 @@ class GenerateCandidates {
    }

    if (print_config_ && all_mc.size() > 1) {
      fprintf(stderr, "MC: ");
      fprintf(stderr, "num_B %zu: MC: ", num_B_);
      for (size_t mc : all_mc) {
        fprintf(stderr, "%zu ", mc);
      }
@@ -252,14 +253,15 @@ class GenerateCandidates {

  // The number of (possibly L3 resident) B rows per `NT_MT` task.
  SizeVec NC(size_t mr, size_t mc, size_t kc, MMOrder order) const {
    size_t nc_max = N_;
    size_t nc_max = kMaxNC;
    // Only if there will be reuse of B: choose the largest `nc_max` (C cols)
    // such that `nc x kc` of B and `mc x nc` of `partial` or `C` fit in L3.
    // Otherwise, leave it unbounded.
    // such that `nc x kc` of B and `mc x nc` of `C` fit in L3. Otherwise,
    // leave it unbounded.
    if (M_ > mr) {
      const size_t bytes_per_nc = (kc * sizeof(BF16) + mc * sizeof_TC_);
      nc_max = HWY_MIN(hwy::DivCeil(cache_.L3Bytes(), bytes_per_nc), N_);
      nc_max = HWY_MIN(hwy::DivCeil(cache_.L3Bytes(), bytes_per_nc), kMaxNC);
    }
    nc_max = HWY_MIN(nc_max, N_);
    HWY_DASSERT(nc_max != 0);
    nc_max = RoundDownWithFloor(nc_max, nc_multiple_);
@@ -278,7 +280,7 @@ class GenerateCandidates {
    if (N_ > nc_multiple_) {
      // Large L3, but its behavior and characteristics varies across platforms,
      // hence autotune a wider range of nc than the other dimensions.
      size_t reps = 10;
      size_t reps = 9 + num_B_;
      // For small M, we can afford larger NC, hence allow fewer small options.
      if (M_ <= 2 * mr) reps -= 1;
@@ -301,7 +303,7 @@ class GenerateCandidates {
    }

    if (print_config_ && all_nc.size() > 1) {
      fprintf(stderr, "NC: ");
      fprintf(stderr, "num_B %zu: NC: ", num_B_);
      for (size_t nc : all_nc) {
        fprintf(stderr, "%zu ", nc);
      }
@@ -329,6 +331,7 @@ class GenerateCandidates {
  const size_t M_;
  const size_t K_;
  const size_t N_;
  const size_t num_B_;
  const size_t sizeof_TC_;

  const size_t kc_multiple_;
@@ -341,12 +344,13 @@ class GenerateCandidates {

// Facade to avoid exposing `GenerateCandidates` in the header.
std::vector<MMConfig> MMCandidates(const CacheInfo& cache, size_t M, size_t K,
                                   size_t N, size_t sizeof_TC,
                                   size_t N, size_t num_B, size_t sizeof_TC,
                                   bool print_config) {
  return GenerateCandidates(cache, M, K, N, sizeof_TC, print_config)();
  return GenerateCandidates(cache, M, K, N, num_B, sizeof_TC, print_config)();
}

MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx), storage(ctx.allocator) {
MatMulEnv::MatMulEnv(ThreadingContext& ctx)
    : ctx(ctx), A_BF(ctx.allocator), C_tiles(ctx) {
  const size_t num_clusters = ctx.pools.AllClusters(/*pkg_idx=*/0).NumWorkers();
  per_cluster.resize(num_clusters);
  for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) {

ops/matmul.h (126 changed lines)
@@ -21,7 +21,6 @@
#include <stddef.h>
#include <stdint.h>

#include <functional>
#include <vector>

// IWYU pragma: begin_exports
@@ -54,7 +53,9 @@ HWY_INLINE_VAR constexpr size_t kNR = 4;
// or less on ISAs with fewer registers, or for the last few rows of A.
HWY_INLINE_VAR constexpr size_t kMaxMR = 4;

HWY_INLINE_VAR constexpr size_t kMaxNC = 16384;  // TODO: shrink?
// For `MMTilesC`.
HWY_INLINE_VAR constexpr size_t kMaxMC = 512;
HWY_INLINE_VAR constexpr size_t kMaxNC = 16384;

// Upper bound for per-worker B storage on the stack. Chosen such that one row
// of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`.
@@ -108,9 +109,9 @@ struct MMParallelWithinCluster {
    hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
    const size_t base = ctx.Worker(cluster_idx);

    const IndexRangePartition worker_ranges = StaticPartition(
    const IndexRangePartition ranges_n = StaticPartition(
        range_n, cluster.NumWorkers() * inner_tasks, n_multiple);
    ParallelizeOneRange(worker_ranges, cluster,
    ParallelizeOneRange(ranges_n, cluster,
                        [&](const IndexRange& worker_range, size_t worker) {
                          func(worker_range, base + worker);
                        });
@@ -169,20 +170,20 @@ struct MMParallelHierarchical {
    if (num_clusters == 1) {
      const size_t cluster_idx = 0;
      hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
      const IndexRangePartition worker_ranges = StaticPartition(
      const IndexRangePartition ranges_n = StaticPartition(
          range_n, cluster.NumWorkers() * inner_tasks, n_multiple);
      return ParallelizeOneRange(
          worker_ranges, cluster,
          ranges_n, cluster,
          [&](const IndexRange& worker_range, size_t worker) {
            func(worker_range, worker);
          });
    }

    // Assign each cluster a sub-range of `range_n` (typically hundreds).
    const IndexRangePartition n_ranges =
    const IndexRangePartition ranges_n =
        StaticPartition(range_n, num_clusters, n_multiple);
    ParallelizeOneRange(
        n_ranges, all_clusters,
        ranges_n, all_clusters,
        [&](const IndexRange& n_range, const size_t cluster_idx) {
          hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
          const size_t cluster_base = ctx.Worker(cluster_idx);
@@ -274,32 +275,51 @@ void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC);
// C is BF16/float.
void BindC(ThreadingContext& ctx, MatPtr& C);

// For A.
class MMStorage {
// Space for converting A=F32 to BF16 before the matmul. This is faster than
// on-the-fly when native BF16 is available: it only happens once, not per B
// tile row, and the cache footprint is smaller.
class MMEntireA {
 public:
  // Compile-time bounds on matrix columns to enable pre-allocating storage
  // and reusing it across `MatMul` calls. Sufficient for Gemma 2 27B.
  static constexpr size_t kMaxK = 36 * 1024;

  MMStorage(const Allocator& allocator)
  explicit MMEntireA(const Allocator& allocator)
      // 288 MiB. Must be padded, see `DoDecompressA`.
      : A_("A_bf", Extents2D(kMaxBatchSize, kMaxK), allocator,
           MatPadding::kOdd) {}

  // Returns matrix view. Converting A=F32 to BF16 up-front is faster than
  // on-the-fly when native BF16 is available: it only happens once, not per B
  // tile row, and the cache footprint is smaller.
  StridedViewBF A(const Extents2D& extents) const {
    HWY_DASSERT(extents.rows <= kMaxBatchSize);
    HWY_DASSERT(extents.cols <= kMaxK);
    return StridedViewBF(const_cast<BF16*>(A_.Row(0)), extents.cols,
                         A_.Stride());
    return StridedViewBF(A_, 0, 0, extents.cols);
  }

 private:
  MatStorageT<BF16> A_;
};

// One tile of C per *worker* (required for `kNT_MT*`).
class MMTilesC {
 public:
  explicit MMTilesC(const ThreadingContext& ctx) {
    const size_t max_workers = ctx.pools.MaxWorkers();
    C_.reserve(max_workers);
    for (size_t worker = 0; worker < max_workers; ++worker) {
      C_.push_back(MatStorageT<BF16>("Ctile", Extents2D(kMaxBatchSize, kMaxNC),
                                     ctx.allocator, MatPadding::kOdd));
    }
  }

  StridedViewBF C(const Extents2D& extents, size_t worker) const {
    HWY_DASSERT(extents.rows <= kMaxBatchSize);
    HWY_DASSERT(worker < C_.size());
    return StridedViewBF(C_[worker], 0, 0, extents.cols);
  }

 private:
  std::vector<MatStorageT<BF16>> C_;
};

//------------------------------------------------------------------------------
// Autotuning
@@ -471,7 +491,7 @@ static_assert(sizeof(MMConfig) == 32);  // for faster indexing
#pragma pack(pop)

std::vector<MMConfig> MMCandidates(const CacheInfo& cache, size_t M, size_t K,
                                   size_t N, size_t sizeof_TC,
                                   size_t N, size_t num_B, size_t sizeof_TC,
                                   bool print_config);

// State machine for choosing the best `TConfig`, which is `MMConfig` for the
@@ -595,12 +615,14 @@ class MMKeys {
  static constexpr Key kPadding = 0;

  // Compresses the dimensions into a single Key for faster comparison.
  static Key KeyFromDims(size_t M, size_t K, size_t N) {
  static Key KeyFromDims(size_t M, size_t K, size_t N, size_t num_B) {
    HWY_DASSERT(M < (Key{1} << 16));  // batch sizes are smaller
    HWY_DASSERT(K < (Key{1} << 24));
    HWY_DASSERT(N < (Key{1} << 24));
    HWY_DASSERT(K < (Key{1} << 20));
    HWY_DASSERT(N < (Key{1} << 20));
    HWY_DASSERT(num_B == 1 || num_B == 2);
    const Key key = static_cast<Key>(BucketM(M)) | (static_cast<Key>(K) << 16) |
                    (static_cast<Key>(N) << 40) |
                    (static_cast<Key>(num_B) << 60);
    HWY_DASSERT(key != kPadding);
    return key;
  }
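The bit layout implied by the shifts above (BucketM itself is not shown in this
diff, so the low 16 bits are only assumed here): M bucket in bits 0..15, K in
bits 16..35, N in bits 40..59, num_B in bits 60..61. A small hypothetical check:

    #include <cassert>
    #include <cstdint>

    const uint64_t key = (uint64_t{4096} << 16) |   // K
                         (uint64_t{16384} << 40) |  // N
                         (uint64_t{2} << 60);       // num_B (M bucket omitted)
    assert(((key >> 16) & 0xFFFFF) == 4096);
    assert(((key >> 40) & 0xFFFFF) == 16384);
    assert((key >> 60) == 2);  // shapes with one vs. two B tune separately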
@@ -643,10 +665,6 @@ class MMKeys {

// Per-MatMul-shape state.
struct MMPerKey {
  // Only profile if enabled and the main autotuner finished. `autotune_par_a`
  // might not be active if inputs are all BF16.
  bool WantProfile() const { return PROFILER_ENABLED != 0 && autotune.Best(); }

  MMAutoTune<MMConfig> autotune;
  MMAutoTune<MMParA> autotune_par_a;
};
@@ -666,12 +684,15 @@ struct MatMulEnv {
  // Whether to print the best config immediately after autotuning finished.
  bool print_best = false;

  MMStorage storage;
  MMEntireA A_BF;
  MMTilesC C_tiles;

  struct PerCluster {
    MMKeys keys;
    std::vector<MMPerKey> per_key;
    HWY_MAYBE_UNUSED uint8_t padding[HWY_ALIGNMENT];  // prevent false sharing
    // Prevents false sharing.
    HWY_MAYBE_UNUSED uint8_t
        padding[HWY_ALIGNMENT - sizeof(MMKeys) - sizeof(per_key)];
  };
  std::vector<PerCluster> per_cluster;
@@ -687,31 +708,57 @@ struct MatMulEnv {
  std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
};

// Called with the entire C matrix, the sub-ranges of M (rows) and N (cols)
// that this thread has just filled, a view into a second tile (only for the
// upcoming `GatedMatmul`), and the worker thread index (see `ParallelFor`).
using MMFused = std::function<void(RowPtrsBF, IndexRange, IndexRange,
                                   StridedViewBF, size_t)>;
// Called via `CallClosure`, which consumes the first (opaque) argument. User
// functions are called with the entire C matrix, the sub-ranges of M (rows)
// and N (cols) that this thread has just filled, a view into a second tile
// (only for `TwoMatmul`), and the worker thread index (see `ParallelFor`).
typedef void (*MMFunc)(const void* opaque, RowPtrsBF, IndexRange, IndexRange,
                       StridedViewBF, size_t);

class MMOptions {
  // Same technique as in `hwy::ThreadPool` and C++23 `std::function_ref`:
  // type-erasure without allocation.
  template <class Closure>
  static void CallClosure(const void* opaque, RowPtrsBF C1, IndexRange range_r,
                          IndexRange range_c, StridedViewBF C2, size_t worker) {
    (*reinterpret_cast<const Closure*>(opaque))(C1, range_r, range_c, C2,
                                                worker);
  }

 public:
  // `closure` must remain alive until the end of (Two)MatMul.
  template <class Closure>
  void SetFunc(const Closure& closure) {
    func = static_cast<MMFunc>(&CallClosure<Closure>);
    opaque = &closure;
  }

  void MaybeCallFunc(RowPtrsBF C1, IndexRange range_r, IndexRange range_c,
                     StridedViewBF C2, size_t worker) const {
    if (func != nullptr) {
      func(opaque, C1, range_r, range_c, C2, worker);
    }
  }

  MMFunc func = nullptr;  // called if non-null and `TC` is BF16.
  const void* opaque = nullptr;

struct MMOptions {
  uint32_t cluster_idx = 0;  // for `parallelism == kWithinCluster`.
  ParallelismStrategy parallelism = ParallelismStrategy::kHierarchical;

  MMFused fused;  // called if non-null and `TC` is BF16.
};
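The SetFunc/CallClosure pattern above, reduced to a standalone sketch (generic
names, not the gemma.cpp API): a function pointer plus an opaque pointer give
std::function-like behavior without heap allocation or virtual dispatch, as
long as the closure outlives the call.

    #include <cstdio>

    typedef void (*Fn)(const void* opaque, int arg);

    struct Options {
      template <class Closure>
      static void CallClosure(const void* opaque, int arg) {
        (*static_cast<const Closure*>(opaque))(arg);
      }
      template <class Closure>
      void SetFunc(const Closure& closure) {  // closure must stay alive
        fn = &CallClosure<Closure>;
        opaque = &closure;
      }
      Fn fn = nullptr;
      const void* opaque = nullptr;
    };

    int main() {
      int captured = 5;
      const auto lambda = [&](int arg) { std::printf("%d\n", captured + arg); };
      Options opts;
      opts.SetFunc(lambda);
      opts.fn(opts.opaque, 2);  // prints 7
    }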
// Arguments to MatMul() that are independent of the A/B/C types. Reduces
// register pressure compared to individual values/references. Also used for
// passing through `DispatchOrder`.
struct MMArgs {
  MMArgs(MatMulEnv& env, size_t M, size_t K, size_t N, double scale,
  MMArgs(MatMulEnv& env, size_t M, size_t K, size_t N, float scale_A,
         const float* HWY_RESTRICT add, MMOptions options,
         const MMAutoTune<MMConfig>& autotune, const MMConfig& config)
      : env(env),
        line_bytes(env.ctx.cache_info.LineBytes()),

        range_n(0, N),
        scale(scale),
        scale_A(scale_A),
        add(add),
        options(options),
@@ -728,7 +775,8 @@ struct MMArgs {

  // MatMul arguments:
  const IndexRange range_n;  // entire N
  const double scale;
  // There can be two B, so do not yet multiply together the A and B scales.
  const float scale_A;
  const float* HWY_RESTRICT add;
  const MMOptions options;
@@ -53,6 +53,14 @@ namespace HWY_NAMESPACE {
// included from matmul_static_*.cc.
GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DEFINE_ONE, GEMMA_MATMUL_TB)  // NOLINT

HWY_MAYBE_UNUSED void TwoMatMulStatic(const MatPtrT<BF16>& A,  // NOLINT
                                      const MatPtrT<GEMMA_MATMUL_TB>& B1,
                                      const MatPtrT<GEMMA_MATMUL_TB>& B2,
                                      MatMulEnv& env, MatPtrT<BF16>& C,
                                      MMOptions options) {
  TwoMatMul(A, B1, B2, env, C, options);
}

}  // namespace HWY_NAMESPACE
}  // namespace gcpp
HWY_AFTER_NAMESPACE();
@@ -37,13 +37,19 @@
      const float* HWY_RESTRICT add, MatMulEnv& env, \
      MatPtrT<TC>& C, MMOptions options);

#define GEMMA_MATMUL_FOR_B(TB) \
  GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, TB) \
  void TwoMatMulStatic(const MatPtrT<BF16>& A, const MatPtrT<TB>& B1, \
                       const MatPtrT<TB>& B2, MatMulEnv& env, \
                       MatPtrT<BF16>& C, MMOptions options);

// Passed to HWY_VISIT_TARGETS; declares all overloads for all targets.
#define GEMMA_MATMUL_DECL(TARGET, NAMESPACE) \
  namespace NAMESPACE { \
  GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, BF16) \
  GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, float) \
  GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, NuqStream) \
  GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, SfpStream) \
  GEMMA_MATMUL_FOR_B(BF16) \
  GEMMA_MATMUL_FOR_B(float) \
  GEMMA_MATMUL_FOR_B(NuqStream) \
  GEMMA_MATMUL_FOR_B(SfpStream) \
  /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
  }  // namespace NAMESPACE
@@ -29,6 +29,8 @@
#include <stddef.h>
#include <stdio.h>

#include <atomic>

#include "ops/matmul.h"
#include "util/basics.h"
#include "util/mat.h"
@@ -246,7 +248,9 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
  MatStorageT<TC> C_slow("C_slow", C_extents, env.ctx.allocator,
                         MatPadding::kOdd);
  MatStorageT<TC> C("C", C_extents, env.ctx.allocator, MatPadding::kOdd);
  MatStorageT<TC> C2("C", C_extents, env.ctx.allocator, MatPadding::kOdd);
  C.AllocateAndAttachRowPtrs(env.row_ptrs);
  C2.AllocateAndAttachRowPtrs(env.row_ptrs);

  MatStorageT<float> add_storage =
      add ? GenerateMat<float>(Extents2D(1, cols_bc), env.ctx.allocator,
@@ -262,7 +266,48 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
  for (size_t rep = 0; rep < 16; ++rep) {
    MMPerKey* per_key = MatMulStatic(A, BT, add_row, env, C, options);
    AssertClose(A, BT, C_slow, C, env, line);
    if (per_key->autotune.Best()) break;
    // Check before TwoMatMulStatic(), which can invalidate per_key.
    const bool autotune_done = !!per_key->autotune.Best();

    // Ensure the tiled view returns the same result as C.
    if constexpr (IsBF16<TA>() && IsBF16<TC>()) {
      // The total view area should match the entire C matrix.
      std::atomic<size_t> total_view_area = 0;

      const auto fused = [&](RowPtrsBF C2_rows, IndexRange range_r,
                             IndexRange range_c, StridedViewBF C2_view,
                             size_t worker) {
        total_view_area.fetch_add(range_r.Num() * range_c.Num());
        HWY_ASSERT(range_c.Num() <= C2_view.Cols());
        HWY_ASSERT(worker < env.ctx.pools.MaxWorkers());
        for (size_t ir = 0; ir < range_r.Num(); ++ir) {
          const size_t r = range_r.begin() + ir;
          for (size_t ic = 0; ic < range_c.Num(); ++ic) {
            const size_t c = range_c.begin() + ic;
            const float expected =
                hwy::ConvertScalarTo<float>(C2_rows.Row(r)[c]);
            const float actual =
                hwy::ConvertScalarTo<float>(C2_view.Row(ir)[ic]);
            const float L1 = hwy::ScalarAbs(actual - expected);
            if (L1 > 1E-6f) {
              HWY_ABORT("%zu: ir %zu ic %zu L1 %f expected %f actual %f.",
                        worker, ir, ic, L1, expected, actual);
            }
          }
        }
      };
      options.SetFunc(fused);
      TwoMatMulStatic(A, BT, BT, env, C2, options);
      HWY_ASSERT_EQ(C.Extents().Area(), total_view_area.load());
      options.func = nullptr;  // reset for next call

      // TwoMatMulStatic() does not support adding a bias vector.
      if (!add) {
        AssertClose(A, BT, C, C2, env, line);
      }
    }

    if (autotune_done) break;
  }
}
@@ -69,6 +69,14 @@ MMPerKey* CallMatMul(const MatPtrT<TA>& A, const MatPtr& B,
  });
}

static inline void CallTwoMatMul(const MatPtrT<BF16>& A, const MatPtr& B1,
                                 const MatPtr& B2, MatMulEnv& env,
                                 MatPtrT<BF16>& C, const MMOptions& options) {
  return CallUpcastedSame(&B1, &B2, [&](const auto* B1_t, const auto* B2_t) {
    return TwoMatMulStatic(A, *B1_t, *B2_t, env, C, options);
  });
}

HWY_INLINE double PackTokenAndProb(int32_t token, float prob) {
  // casting prob from float to double just makes some changes to the
  // exponent bias and pads zeros in the mantissa.

util/mat.h (13 changed lines)
@@ -40,10 +40,11 @@ class RowPtrs {
 public:
  RowPtrs(uint8_t** row_ptrs) : row_ptrs_(row_ptrs), r0_(0), c0_(0) {}

  RowPtrs View(size_t r, size_t c) {
  // Extra argument is for compatibility with `StridedView`.
  RowPtrs View(size_t r, size_t c, size_t /*cols*/) {
    RowPtrs<T> view(row_ptrs_);
    view.r0_ = static_cast<uint32_t>(r);
    view.c0_ = static_cast<uint32_t>(c);
    view.r0_ = static_cast<uint32_t>(r0_ + r);
    view.c0_ = static_cast<uint32_t>(c0_ + c);
    return view;
  }
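The second change in View above makes offsets compose instead of restarting
from the full matrix, which matters now that B3A2C0 receives a C view that is
itself a sub-view. A hedged illustration (construction simplified; `row_ptrs`
is the usual uint8_t** backing array):

    RowPtrs<float> m(row_ptrs);           // whole matrix, offset (0, 0)
    auto a = m.View(2, 3, /*cols=*/8);    // offset (2, 3)
    auto b = a.View(1, 1, /*cols=*/4);    // offset (3, 4); previously (1, 1)
    // b.Row(0) now addresses row 3 of the original matrix, column offset 4.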
@@ -531,7 +532,11 @@ class StridedView {
      : row0_(row0),
        cols_(static_cast<uint32_t>(cols)),
        stride_(static_cast<uint32_t>(stride)) {
    HWY_DASSERT(stride >= cols);
    if constexpr (HWY_IS_DEBUG_BUILD) {
      if (stride < cols) {
        HWY_ABORT("stride %zu < cols %zu", stride, cols);
      }
    }
  }

  // Returns 2D subrange whose top-left is `r, c` and width is `cols`.