mirror of https://github.com/google/gemma.cpp.git
1.03x speedup: fused FFN
matmul-inl: support CView=StridedView or RowPtrs; rename to C_MC_NC.
matmul.cc: Allow 1 more rep for MC/NC to allow half-sized tiles, which helps.
PiperOrigin-RevId: 807291701
This commit is contained in:
parent
59db30e209
commit
f3bc1c17da
@@ -37,7 +37,6 @@ class GemmaBatchBench : public ::testing::Test {
 protected:
  std::vector<std::string> BatchGemmaReply(
      const std::vector<std::string>& inputs) {
-    s_env->SetMaxGeneratedTokens(24);
    s_env->MutableConfig().temperature = 0.0f;  // deterministic
    s_env->MutableConfig().verbosity = 2;
    std::vector<std::string> replies;

@@ -92,15 +91,18 @@ TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
    inputs.push_back(questions[qpos++]);
    if (qpos == questions.size()) qpos = 0;
  }
+  s_env->SetMaxGeneratedTokens(24);
  std::vector<std::string> responses = BatchGemmaReply(inputs);
  for (size_t i = 0; i < HWY_MIN(hwy::Unpredictable1() * 3, responses.size());
       ++i) {
    fprintf(stderr, "Batch answer %zu '%s'\n\n", i, responses[i].c_str());
  }

+  PROFILER_PRINT_RESULTS();
+
  // Run again: prefill will be faster due to autotuning. Fewer decode steps
  // because those are already fast.
-  s_env->SetMaxGeneratedTokens(3);
+  s_env->SetMaxGeneratedTokens(2);
  responses = BatchGemmaReply(inputs);

  PROFILER_PRINT_RESULTS();
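The test now warms up with 24 generated tokens so the MatMul autotuner converges, prints profiler results, then re-runs with only 2 tokens so the second measurement mostly reflects tuned kernels. A minimal sketch of this warm-up-then-measure pattern; RunBatch is a hypothetical stand-in for BatchGemmaReply, not a gemma.cpp API:

#include <chrono>
#include <cstdio>
#include <string>
#include <vector>

// Hypothetical stand-in for the benchmark's BatchGemmaReply.
std::vector<std::string> RunBatch(const std::vector<std::string>& inputs,
                                  size_t max_tokens) {
  (void)max_tokens;
  return std::vector<std::string>(inputs.size(), "reply");
}

void WarmupThenMeasure(const std::vector<std::string>& inputs) {
  // First pass (24 tokens): long enough for per-shape autotuning to converge;
  // its time is dominated by trying candidate configurations.
  (void)RunBatch(inputs, /*max_tokens=*/24);

  // Second pass (2 tokens): short, so timing mostly reflects tuned kernels.
  const auto t0 = std::chrono::steady_clock::now();
  (void)RunBatch(inputs, /*max_tokens=*/2);
  const auto t1 = std::chrono::steady_clock::now();
  std::fprintf(stderr, "tuned pass: %.1f ms\n",
               std::chrono::duration<double, std::milli>(t1 - t0).count());
}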
@@ -36,6 +36,10 @@ HWY_INLINE_VAR constexpr int kAttentionUseOld = 2;
 HWY_INLINE_VAR constexpr size_t kMaxQKVDim = 1024;

+#ifndef GEMMA_FUSED_FFN
+#define GEMMA_FUSED_FFN 1
+#endif  // !GEMMA_FUSED_FFN
+
 // Instruction-tuned models require extra 'turn structure' tokens in prompts.
 enum class PromptWrapping {
   GEMMA_IT,
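The #ifndef/#define pair makes the fusion an overridable default: because of the guard, a build can turn it off by defining GEMMA_FUSED_FFN=0 before this header is seen, for example on the compiler command line. A minimal sketch of the same toggle pattern using a hypothetical MYLIB_FEATURE_X flag:

// In a header: default the feature to ON, but let the build override it,
// e.g. with -DMYLIB_FEATURE_X=0.
#ifndef MYLIB_FEATURE_X
#define MYLIB_FEATURE_X 1
#endif  // !MYLIB_FEATURE_X

#include <cstdio>

int main() {
#if MYLIB_FEATURE_X
  std::puts("feature X enabled (default)");
#else
  std::puts("feature X disabled via -DMYLIB_FEATURE_X=0");
#endif
  return 0;
}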
@@ -43,6 +43,7 @@ HWY_BEFORE_NAMESPACE();
 namespace gcpp {
 namespace HWY_NAMESPACE {

+// For use by Vit even if !GEMMA_FUSED_FFN.
 template <typename T1, typename T2>
 void Activation(ActivationType activation, T1* HWY_RESTRICT c1,
                 const T2* HWY_RESTRICT c2, const size_t count, hwy::Profiler& p,

@@ -64,7 +65,7 @@ void Activation(ActivationType activation, T1* HWY_RESTRICT c1,
   });
 }

-// No C2 multiplier.
+// No C2 multiplier - used by Vit.
 template <class Mat>
 void ActivationBatched(
     ActivationType activation, Mat& c1, ThreadingContext& ctx,
@@ -80,6 +81,34 @@ void ActivationBatched(
   });
 }

+#if GEMMA_FUSED_FFN
+
+// Called during `TwoMatMul`.
+static inline void Activation(ActivationType activation, const RowPtrsBF C1,
+                              const IndexRange range_r,
+                              const IndexRange range_c, const StridedViewBF C2,
+                              hwy::Profiler& p, const size_t worker) {
+  static const auto zone = p.AddZone("Gen.ActivationFused");
+  PROFILER_ZONE3(p, worker, zone);
+
+  const size_t cols = range_c.Num();
+  HWY_DASSERT(C2.Cols() == cols);
+
+  namespace hn = hwy::HWY_NAMESPACE;
+  using DF = hn::ScalableTag<float>;
+  using VF = hn::Vec<DF>;
+  // ActivationType::Gelu
+  // Gated: Gelu(c1) * c2.
+  for (size_t ir = 0; ir < range_r.Num(); ++ir) {
+    Decompress1AndCompressInplace(
+        DF(), C1.Row(range_r.begin() + ir) + range_c.begin(), cols, C2.Row(ir),
+        [](DF df, VF v1, VF v2)
+            HWY_ATTR -> VF { return hn::Mul(v2, Gelu(df, v1)); });
+  }
+}
+
+#else
+
 template <class Mat1, class Mat2>
 HWY_NOINLINE void ActivationBatched(
     ActivationType activation, Mat1& c1, const Mat2* c2, ThreadingContext& ctx,
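The fused Activation applies gated GELU in place: for each row of the current C1 tile it computes Gelu(c1) * c2, reading the gate from the per-worker C2 tile and writing back into C1. A scalar reference for that element-wise step, a sketch assuming the common tanh approximation of GELU (the SIMD kernel above uses Highway's Gelu, which may differ in detail):

#include <cmath>
#include <cstddef>

// Scalar reference for the gated activation: c1[i] = Gelu(c1[i]) * c2[i].
float GeluApprox(float x) {
  // tanh approximation of GELU; the library kernel may differ in detail.
  const float k = 0.7978845608f;  // sqrt(2/pi)
  return 0.5f * x * (1.0f + std::tanh(k * (x + 0.044715f * x * x * x)));
}

void GatedGeluRow(float* c1, const float* c2, size_t cols) {
  for (size_t i = 0; i < cols; ++i) {
    c1[i] = GeluApprox(c1[i]) * c2[i];
  }
}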
@@ -102,6 +131,8 @@ HWY_NOINLINE void ActivationBatched(
   }
 }

+#endif  // GEMMA_FUSED_FFN
+
 template <typename T2, class LayerWeights>
 HWY_NOINLINE void ResidualConnection(const MatPtrT<T2>& other,
                                      MatPtrT<float>& HWY_RESTRICT x,
@@ -126,28 +157,32 @@ static inline void FFWNoVit(const LayerWeightsPtrs& layer,
       env.ctx.profiler.AddZone("Gen.FFW", hwy::ProfilerFlags::kInclusive);
   PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone);
   const LayerConfig& layer_config = layer.layer_config;
-  const size_t ffh_hidden_dim = layer_config.ff_hidden_dim;

-  const bool add_bias = layer_config.ff_biases;
-  const float* bias1 =
-      add_bias ? layer.ffw_gating_biases.PackedScale1() : nullptr;
-  const float* bias2 = add_bias ? bias1 + ffh_hidden_dim : nullptr;
-  const float* output_bias =
-      add_bias ? layer.ffw_output_biases.PackedScale1() : nullptr;
+  HWY_DASSERT(!layer_config.ff_biases);  // Only used in Vit.

+#if GEMMA_FUSED_FFN
+  const auto fused = [&](RowPtrsBF C1, IndexRange range_r, IndexRange range_c,
+                         StridedViewBF C2, size_t worker) {
+    Activation(layer_config.activation, C1, range_r, range_c, C2,
+               env.ctx.profiler, worker);
+  };
+  MMOptions options;
+  options.SetFunc(fused);
+  CallTwoMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w1,
+                layer.gating_einsum_w2, env, activations.C1, options);
+#else
   // Compute the hidden layer activations.
-  CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w1, bias1, env,
+  CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w1, nullptr, env,
              activations.C1);
-  CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w2, bias2, env,
+  CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w2, nullptr, env,
              activations.C2);

   // Activation (Gelu) and maybe multiply by gate. Store activations in act.
   ActivationBatched(layer_config.activation, activations.C1, &activations.C2,
                     env.ctx);
+#endif

   // Hidden layer -> output layer.
-  CallMatMul(activations.C1, layer.linear_w, output_bias, env,
-             activations.ffw_out);
+  CallMatMul(activations.C1, layer.linear_w, nullptr, env, activations.ffw_out);
 }

 // NOLINTNEXTLINE(google-readability-namespace-comments)
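With fusion enabled, C2 is never materialized as a full matrix: CallTwoMatMul computes both products tile by tile and invokes the fused callback while the second result is still a small, cache-resident scratch tile. A simplified, single-threaded sketch of that control flow, with std::vector matrices standing in for the library's matrix and view types and only the N dimension tiled:

#include <algorithm>
#include <cstddef>
#include <functional>
#include <vector>

// A is M x K (row-major), B1/B2 are N x K (rows are output columns, as in the
// library), C1 is M x N and must be pre-sized by the caller.
using Matrix = std::vector<std::vector<float>>;

void TwoMatMulFused(const Matrix& A, const Matrix& B1, const Matrix& B2,
                    Matrix& C1, size_t n_tile,
                    const std::function<void(Matrix& c1, const Matrix& c2_tile,
                                             size_t n0, size_t nc)>& fuse) {
  const size_t M = A.size(), K = A[0].size(), N = B1.size();
  for (size_t n0 = 0; n0 < N; n0 += n_tile) {
    const size_t nc = std::min(n_tile, N - n0);
    Matrix c2_tile(M, std::vector<float>(nc, 0.0f));  // per-tile scratch
    for (size_t m = 0; m < M; ++m) {
      for (size_t j = 0; j < nc; ++j) {
        float acc1 = 0.0f, acc2 = 0.0f;
        for (size_t k = 0; k < K; ++k) {
          acc1 += A[m][k] * B1[n0 + j][k];
          acc2 += A[m][k] * B2[n0 + j][k];
        }
        C1[m][n0 + j] = acc1;
        c2_tile[m][j] = acc2;
      }
    }
    // e.g. gating: C1[m][n0 + j] = Gelu(C1[m][n0 + j]) * c2_tile[m][j].
    fuse(C1, c2_tile, n0, nc);
  }
}

The real kernels also tile M and K, run tiles on a thread pool, and keep the scratch tile in per-worker storage.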
ops/matmul-inl.h
@@ -155,14 +155,14 @@ class MMStoreHorizontalSumsIntoC {
   template <class D4, class V4 = hn::Vec<D4>, class Tag, class CView>
   HWY_INLINE void Store(D4 d4, V4 sum0, V4 sum1, V4 sum2, V4 sum3,
                         const float scale, const float* HWY_RESTRICT add,
-                        const size_t imc, Tag tag, CView C_rows) const {
+                        const size_t imc, Tag tag, CView C_MC_NR) const {
     const V4 vscale = hn::Set(d4, scale);
     HWY_ALIGN static constexpr float kZero[4] = {};
     const V4 vadd = hn::Load(d4, add ? add : kZero);
-    MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, tag, imc, C_rows);
-    MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, tag, imc, C_rows);
-    MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, tag, imc, C_rows);
-    MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, tag, imc, C_rows);
+    MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, tag, imc, C_MC_NR);
+    MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, tag, imc, C_MC_NR);
+    MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, tag, imc, C_MC_NR);
+    MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, tag, imc, C_MC_NR);
   }

  private:

@@ -202,10 +202,10 @@ class MMStoreHorizontalSumsIntoC {
             class Tag, class CView>
   static HWY_INLINE void MaybeScaleAndStore(DF4 df4, VF4 sum, VF4 vscale,
                                             VF4 vadd, Tag, const size_t imc,
-                                            CView C_view) {
+                                            CView C_MC_NR) {
     if constexpr (kRow < kRowsAC) {
-      using TC = hwy::RemoveCvRef<decltype(C_view.Row(0)[0])>;
-      TC* HWY_RESTRICT pos = C_view.Row(imc + kRow);
+      using TC = hwy::RemoveCvRef<decltype(C_MC_NR.Row(0)[0])>;
+      TC* HWY_RESTRICT pos = C_MC_NR.Row(imc + kRow);
       const hn::Rebind<TC, DF4> dc4;
       if constexpr (hwy::IsSame<Tag, MMAddC>()) {
         vadd = F32FromTC(dc4, hn::Load(dc4, pos));  // load prior value
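The store path is now templated on CView, so the same kernel can write either to the caller's output rows (RowPtrs, one pointer per row) or to a per-worker scratch tile (StridedView, base pointer plus row stride); all it needs is a Row(i) accessor. A minimal sketch of two such view types and a store helper that accepts either (names are illustrative, not the library's):

#include <cstddef>
#include <vector>

// Two "C view" flavors with the same Row(i) interface.
struct RowPtrsView {
  std::vector<float*> rows;  // one pointer per output row
  float* Row(size_t r) const { return rows[r]; }
};

struct StridedView {
  float* base = nullptr;  // contiguous tile; row stride in elements
  size_t stride = 0;
  float* Row(size_t r) const { return base + r * stride; }
};

// Generic store: works with either view type.
template <class CView>
void StoreRow(const float* src, size_t cols, size_t r, CView c) {
  float* dst = c.Row(r);
  for (size_t i = 0; i < cols; ++i) dst[i] = src[i];
}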
@@ -268,9 +268,9 @@ class MMDecompress {
     } else {
       // Always decompress. To reduce code size/compile time, we no longer
       // support a separate F32 kernel; most A are already BF16. We also only
-      // have a single MMStorage.
+      // have a single MMEntireA.
       HWY_ASSERT(options.cluster_idx == 0);
-      const StridedViewBF A_view = env.storage.A(A.Extents());
+      const StridedViewBF A_view = env.A_BF.A(A.Extents());
       AutotuneDecompressA(A, A_view, autotune, env, options);
       return A_view;
     }
@@ -387,111 +387,52 @@ class MMDecompress {

 // Stateless, wraps member functions. Contains the innermost 2-4 loops.
 class MMKernel {
-  // Compute size of per-worker storage for `kNR` row ranges of B. Stack
-  // allocation avoids passing a worker index.
-  static constexpr size_t B_stride_max =
-      kMaxKC + 2 * CacheInfo::MaxLineBytes() / sizeof(BF16);
-
  public:
-  // A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output.
-  // Calls `LoopKC` for each of `mc` rows of A in steps of `mr`. `A_view` is
-  // `mc x kc` and `B_view` is `(kNR x kc)`. All views, including `add`, start
-  // at row/col 0. `CView` is either `RowPtrs<TC>` or `StridedView<TC>`.
-  // Called by B3A2C0 and by callers that hoist `A_view`.
-  template <class Tag, class CView>
-  static HWY_INLINE void A2C0(const StridedViewBF A_view,
-                              const StridedViewBF B_view, size_t mr,
-                              const IndexRange& range_mc, size_t kc,
-                              const float scale, const float* HWY_RESTRICT add,
-                              Tag tag, CView C_view) {
-    HWY_DASSERT(1 <= mr && mr <= kMaxMR);
-
-    const size_t mc = range_mc.Num();
-    size_t imc = 0;
-
-    // M == 1, or x86 with 8 SIMD registers:
-    if (HWY_UNLIKELY(mr == 1)) {
-      for (; imc < mc; ++imc) {
-        LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_view);
-      }
-      return;
-    }
-
-    // AVX2 (16 registers)
-    if (HWY_UNLIKELY(mr == 2)) {
-      if (HWY_LIKELY(mc >= 2)) {
-        for (; imc <= mc - 2; imc += 2) {
-          LoopKC<2>(A_view, B_view, imc, kc, scale, add, tag, C_view);
-        }
-      }
-      if (HWY_UNLIKELY(imc != mc)) {
-        LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_view);
-      }
-      return;
-    }
-
-    HWY_DASSERT(mr == 4);
-    if (HWY_LIKELY(mc >= 4)) {
-      for (; imc <= mc - 4; imc += 4) {
-        LoopKC<4>(A_view, B_view, imc, kc, scale, add, tag, C_view);
-      }
-    }
-    const size_t remainder_mc = mc - imc;
-    HWY_DASSERT(remainder_mc < 4);
-    if (HWY_UNLIKELY(remainder_mc & 2)) {
-      LoopKC<2>(A_view, B_view, imc, kc, scale, add, tag, C_view);
-      imc += 2;
-    }
-    if (HWY_UNLIKELY(remainder_mc & 1)) {
-      LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_view);
-      imc += 1;
-    }
-    HWY_DASSERT(imc == mc);
-  }
-
-  static constexpr size_t B_storage_max = kNR * B_stride_max;
-
   // Loop over NC/MC/KC, called from the outer loops. The MOMMS B3A2C0 reads
-  // `mc x kc` of A, `nc x kc` of B, and updates `mc x nc` of C. Called by
-  // `ForeachKC` and when there is only a single KC task.
-  template <typename TB, typename TC, typename Tag>
+  // `mc x kc` of A, `nc x kc` of B, and updates the `mc x nc` `C_MC_NC`.
+  // `CView` is either `RowPtrs<TC>` or `StridedView<TC>`.
+  template <typename TB, typename Tag, class CView>
   static void B3A2C0(const StridedViewBF A, const MatPtrT<TB>& B,
                      const IndexRange& range_mc, const IndexRange& range_kc,
                      const IndexRange& range_nc, const MMArgs& args,
-                     Tag out_tag, RowPtrs<TC> C) {
-    HWY_ALIGN BF16 B_storage[B_storage_max];
-
+                     Tag out_tag, CView C_MC_NC) {
     const size_t kc = range_kc.Num();
     const StridedViewBF A_view = A.View(range_mc.begin(), range_kc.begin(), kc);

+    // Upper bound on per-worker storage for `kNR` row ranges of B. Stack
+    // allocation avoids passing a worker index.
+    constexpr size_t B_stride_max =
+        kMaxKC + 2 * CacheInfo::MaxLineBytes() / sizeof(BF16);
+    HWY_ALIGN BF16 B_storage[kNR * B_stride_max];
     const size_t B_stride =
         Stride(MatPadding::kOdd, kc, sizeof(BF16), args.line_bytes);
     const StridedViewBF B_storage_view(B_storage, kc, B_stride);

-    for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
-         row_b += kNR) {
+    const float scale = args.scale_A * B.Scale();
+    for (size_t inc = 0; inc < range_nc.Num(); inc += kNR) {
+      // For `add` and `B`, which are global, unlike `C_MC_NC`.
+      const size_t row_b = range_nc.begin() + inc;
       const StridedViewBF B_view =
           MMDecompress::DecompressB(B, row_b, range_kc, B_storage_view);
-      const RowPtrs<TC> C_view = C.View(range_mc.begin(), row_b);
+      const CView C_MC_NR = C_MC_NC.View(0, inc, kNR);
       const float* HWY_RESTRICT add = args.add ? args.add + row_b : nullptr;
-      A2C0(A_view, B_view, args.mr, range_mc, kc, args.scale, add, out_tag,
-           C_view);
+      A2C0(A_view, B_view, args.mr, range_mc, kc, scale, add, out_tag, C_MC_NR);
     }
   }

-  template <typename TB, typename TC>
+  template <typename TB, class CView>
   static void ForeachKC(const StridedViewBF A, const MatPtrT<TB>& B,
                         const IndexRange& range_mc,
                         const IndexRangePartition& ranges_kc,
                         const IndexRange& range_nc, const MMArgs& args,
-                        RowPtrs<TC> C) {
+                        CView C_MC_NC) {
     // Peel off the first iteration of the kc loop: avoid zero-initializing `C`
     // by writing directly into it, and later accumulating into it.
     ranges_kc.VisitFirst([&](const IndexRange& range_kc) {
-      B3A2C0(A, B, range_mc, range_kc, range_nc, args, MMSetC(), C);
+      B3A2C0(A, B, range_mc, range_kc, range_nc, args, MMSetC(), C_MC_NC);
     });
     ranges_kc.VisitRemaining([&](const IndexRange& range_kc) {
-      B3A2C0(A, B, range_mc, range_kc, range_nc, args, MMAddC(), C);
+      B3A2C0(A, B, range_mc, range_kc, range_nc, args, MMAddC(), C_MC_NC);
     });
   }
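B3A2C0 is the MOMMS-style middle loop: it walks kNR rows of B at a time, decompressing each group into a stack buffer, and A2C0 then steps over the A rows in groups of mr before handing an mr x kNR register tile to LoopKC. A scalar sketch of where the nc/mc/kc cache tiles and the mr x nr register tile sit in the loop nest (no decompression, SIMD, or threading):

#include <algorithm>
#include <cstddef>

// Scalar sketch of the blocked loops: C (M x N) += A (M x K) * B^T, where B is
// stored as N x K (each row of B is one output column, as in the library).
// mc/nc/kc are cache tiles; mr/nr is the register tile handled innermost.
// C must be zero-initialized by the caller.
void BlockedMatMul(const float* A, const float* B, float* C, size_t M, size_t N,
                   size_t K, size_t mc, size_t nc, size_t kc, size_t mr,
                   size_t nr) {
  for (size_t n0 = 0; n0 < N; n0 += nc) {            // NC: B panel (L3)
    const size_t n_end = std::min(n0 + nc, N);
    for (size_t m0 = 0; m0 < M; m0 += mc) {          // MC: A panel (L2)
      const size_t m_end = std::min(m0 + mc, M);
      for (size_t k0 = 0; k0 < K; k0 += kc) {        // KC: columns per C update
        const size_t k_end = std::min(k0 + kc, K);
        for (size_t n = n0; n < n_end; n += nr) {    // kNR rows of B
          for (size_t m = m0; m < m_end; m += mr) {  // mr rows of A (A2C0)
            // Register tile (LoopKC): mr x nr partial sums over kc columns.
            for (size_t mi = m; mi < std::min(m + mr, m_end); ++mi) {
              for (size_t ni = n; ni < std::min(n + nr, n_end); ++ni) {
                float acc = 0.0f;
                for (size_t k = k0; k < k_end; ++k) {
                  acc += A[mi * K + k] * B[ni * K + k];
                }
                C[mi * N + ni] += acc;
              }
            }
          }
        }
      }
    }
  }
}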
@@ -585,15 +526,15 @@ class MMKernel {
   // Innermost loop over `kc` columns (typically 1024-4096, not necessarily a
   // multiple of `NBF`) in steps of one vector, for `kRowsAC` rows of `A_view`
   // from range_mc-relative `imc` and `B_view` from row 0 (both at column 0).
-  // Updates a `kRowsAC x kNR` tile in `C_view` starting at row `imc`, column 0.
-  // `A` and `B` are always BF16, `C` can be F32 or BF16. `add` is also
+  // Updates a `kRowsAC x kNR` tile in `C_MC_NR` starting at row `imc`, column
+  // 0. `A` and `B` are always BF16, `C` can be F32 or BF16. `add` is also
   // relative to the C column.
   template <size_t kRowsAC, /*deduced:*/ class Tag, class CView>
   static HWY_INLINE void LoopKC(const StridedViewBF A_view,
                                 const StridedViewBF B_view, size_t imc,
                                 size_t kc, const float scale,
                                 const float* HWY_RESTRICT add, Tag tag,
-                                CView C_view) {
+                                CView C_MC_NR) {
     const hn::ScalableTag<BF16> dbf;
     using VBF = hn::Vec<decltype(dbf)>;
     HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf);
@@ -777,7 +718,62 @@ class MMKernel {
     hn::Vec<decltype(d4)> sum0, sum1, sum2, sum3;
     horz.Reduce4x4(df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22,
                    C23, C30, C31, C32, C33, sum0, sum1, sum2, sum3);
-    horz.Store(d4, sum0, sum1, sum2, sum3, scale, add, imc, tag, C_view);
+    horz.Store(d4, sum0, sum1, sum2, sum3, scale, add, imc, tag, C_MC_NR);
+  }
+
+  // A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output.
+  // Calls `LoopKC` for each of `mc` rows of A in steps of `mr`. `A_view` is
+  // `mc x kc` and `B_view` is `(kNR x kc)`. All views, including `add`, start
+  // at row/col 0.
+  template <class Tag, class CView>
+  static HWY_INLINE void A2C0(const StridedViewBF A_view,
+                              const StridedViewBF B_view, size_t mr,
+                              const IndexRange& range_mc, size_t kc,
+                              const float scale, const float* HWY_RESTRICT add,
+                              Tag tag, CView C_MC_NR) {
+    HWY_DASSERT(1 <= mr && mr <= kMaxMR);
+
+    const size_t mc = range_mc.Num();
+    size_t imc = 0;
+
+    // M == 1, or x86 with 8 SIMD registers:
+    if (HWY_UNLIKELY(mr == 1)) {
+      for (; imc < mc; ++imc) {
+        LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_MC_NR);
+      }
+      return;
+    }
+
+    // AVX2 (16 registers)
+    if (HWY_UNLIKELY(mr == 2)) {
+      if (HWY_LIKELY(mc >= 2)) {
+        for (; imc <= mc - 2; imc += 2) {
+          LoopKC<2>(A_view, B_view, imc, kc, scale, add, tag, C_MC_NR);
+        }
+      }
+      if (HWY_UNLIKELY(imc != mc)) {
+        LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_MC_NR);
+      }
+      return;
+    }
+
+    HWY_DASSERT(mr == 4);
+    if (HWY_LIKELY(mc >= 4)) {
+      for (; imc <= mc - 4; imc += 4) {
+        LoopKC<4>(A_view, B_view, imc, kc, scale, add, tag, C_MC_NR);
+      }
+    }
+    const size_t remainder_mc = mc - imc;
+    HWY_DASSERT(remainder_mc < 4);
+    if (HWY_UNLIKELY(remainder_mc & 2)) {
+      LoopKC<2>(A_view, B_view, imc, kc, scale, add, tag, C_MC_NR);
+      imc += 2;
+    }
+    if (HWY_UNLIKELY(remainder_mc & 1)) {
+      LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_MC_NR);
+      imc += 1;
+    }
+    HWY_DASSERT(imc == mc);
   }
 };
@@ -813,10 +809,10 @@ class MMImpl {
   }

  public:
-  static MMPerKey& FindOrAddPerKey(size_t M, size_t K, size_t N,
+  static MMPerKey& FindOrAddPerKey(size_t M, size_t K, size_t N, size_t num_B,
                                    size_t vector_bytes,
                                    MatMulEnv::PerCluster& per_cluster) {
-    const MMKeys::Key key = MMKeys::KeyFromDims(M, K, N);
+    const MMKeys::Key key = MMKeys::KeyFromDims(M, K, N, num_B);
     intptr_t index = IndexOfKey(key, per_cluster.keys);
     // First time we see this shape/key.
     if (HWY_UNLIKELY(index < 0)) {

@@ -831,17 +827,19 @@ class MMImpl {
   }

   static void NotifyAutotuneResult(MatMulEnv& env, size_t M, size_t K, size_t N,
-                                   double t0, MMAutoTune<MMConfig>& tuner,
+                                   size_t num_B, double t0,
+                                   MMAutoTune<MMConfig>& tuner,
                                    const MMConfig& cfg) {
     const uint64_t t1 =
         env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
     const double min_elapsed = static_cast<double>(tuner.NotifyTicks(t1 - t0)) /
                                hwy::platform::InvariantTicksPerSecond();
-    const double flops = 2 * M * K * N / min_elapsed;  // * 2 for FMA
+    const double flops = 2 * M * K * N * num_B / min_elapsed;  // * 2 for FMA
     if (HWY_UNLIKELY(env.print_measurement && tuner.ShouldPrint())) {
-      fprintf(stderr, "%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu\n", flops * 1E-9,
-              min_elapsed * 1E3, cfg.MR(), cfg.MC(), cfg.KC(), cfg.NC(),
-              StringFromOrder(cfg.Order()), cfg.InnerTasks());
+      fprintf(stderr, "%zu,%zu,%zu,%zu,%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu\n",
+              M, K, N, num_B, flops * 1E-9, min_elapsed * 1E3, cfg.MR(),
+              cfg.MC(), cfg.KC(), cfg.NC(), StringFromOrder(cfg.Order()),
+              cfg.InnerTasks());
     }
     if (HWY_UNLIKELY(env.print_best && tuner.Best())) {
       const auto ratio = [&tuner](uint64_t ticks) -> double {

@@ -849,9 +847,10 @@ class MMImpl {
                            static_cast<double>(tuner.BestTicks());
       };
       const MMConfig& best = *tuner.Best();
-      fprintf(stderr,
-              "\n%zu,%zu,%zu,%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu,%.2f,%.2f\n",
-              M, K, N, flops * 1E-9, min_elapsed * 1E3, best.MR(), best.MC(),
+      fprintf(
+          stderr,
+          "\n%zu,%zu,%zu,%zu,%7.1f,%.2f,%zu,%4zu,%4zu,%5zu,%s,%zu,%.2f,%.2f\n",
+          M, K, N, num_B, flops * 1E-9, min_elapsed * 1E3, best.MR(), best.MC(),
           best.KC(), best.NC(), StringFromOrder(best.Order()),
           best.InnerTasks(), ratio(tuner.WorstMinTicks()),
           ratio(tuner.FirstConfigTicks()));
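The reported throughput now folds in num_B, since TwoMatMul performs two products per call: FLOPs = 2 * M * K * N * num_B, where the factor 2 counts the multiply and the add of each FMA. A small self-contained sketch of that accounting with illustrative shapes:

#include <cstdio>

int main() {
  // Example shapes and timing; purely illustrative.
  const double M = 64, K = 4096, N = 16384, num_B = 2;
  const double seconds = 0.010;  // measured elapsed time
  const double flops = 2.0 * M * K * N * num_B / seconds;
  std::printf("%.1f GFLOP/s\n", flops * 1e-9);
  return 0;
}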
@@ -874,10 +873,11 @@ class MMImpl {
 class MMLoops {
  public:
   // Called from `MatMul` from two places: either with the next autotune config,
-  // or with the best config.
+  // or with the best config. `B2` is null unless called from `TwoMatMul`.
   template <typename TB, typename TC>
   static HWY_NOINLINE void Dispatch(const StridedViewBF A, const MatPtrT<TB>& B,
-                                    RowPtrs<TC> C, const MMArgs& args) {
+                                    const MatPtrT<TB>* B2, RowPtrs<TC> C,
+                                    const MMArgs& args) {
     static const auto zone = args.env.ctx.profiler.AddZone("MM.Dispatch");
     PROFILER_ZONE3(args.env.ctx.profiler,
                    args.env.ctx.Worker(args.options.cluster_idx), zone);

@@ -885,7 +885,7 @@ class MMLoops {
     DispatchParallelism(
         args.options.parallelism, [&](const auto& parallel) HWY_ATTR {
           DispatchOrder(args.order, [&](const auto& order) HWY_ATTR {
-            Loop(order, parallel, A, B, C, args);
+            Loop(order, parallel, A, B, B2, C, args);
           });
         });
   }
@@ -901,18 +901,14 @@ class MMLoops {
   template <typename TB, typename TC, class Parallel>
   static HWY_INLINE void Loop(MMOrderNT, Parallel parallel,
                               const StridedViewBF A, const MatPtrT<TB>& B,
-                              RowPtrs<TC> C, const MMArgs& args) {
+                              const MatPtrT<TB>* B2, RowPtrs<TC> C,
+                              const MMArgs& args) {
     static const auto zone = args.env.ctx.profiler.AddZone("MM.NT");
     HWY_DASSERT(args.ranges_mc.NumTasks() == 1);
     HWY_DASSERT(args.ranges_kc.NumTasks() == 1);
-    const IndexRange& range_M = args.ranges_mc.Range(0);
-    const IndexRange& range_K = args.ranges_kc.Range(0);
-    const size_t K = range_K.Num();
-    const StridedViewBF A_view = A.View(range_M.begin(), 0, K);
-    const size_t B_stride =
-        Stride(MatPadding::kOdd, K, sizeof(BF16), args.line_bytes);
+    const IndexRange& range_mc = args.ranges_mc.Range(0);
+    const IndexRange& range_kc = args.ranges_kc.Range(0);

-    // Similar to `B3A2C0`, but here we hoisted `A_view`.
     parallel.ForN(
         args.env.ctx, args.range_n, MultipleN(sizeof(TC), args.line_bytes),
         args.inner_tasks, args.options.cluster_idx,

@@ -920,26 +916,19 @@ class MMLoops {
           MMZone mm_zone;
           mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune);

-          HWY_ALIGN BF16 B_storage[MMKernel::B_storage_max];  // TLS
-          const StridedViewBF B_storage_view(B_storage, K, B_stride);
-
-          for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
-               row_b += kNR) {
-            const StridedViewBF B_view =
-                MMDecompress::DecompressB(B, row_b, range_K, B_storage_view);
-            const RowPtrs<TC> C_view = C.View(range_M.begin(), row_b);
-            const float* HWY_RESTRICT add =
-                args.add ? args.add + row_b : nullptr;
-
-            MMKernel::A2C0(A_view, B_view, args.mr, range_M, K, args.scale, add,
-                           MMSetC(), C_view);
+          MMKernel::B3A2C0(A, B, range_mc, range_kc, range_nc, args, MMSetC(),
+                           C.View(0, range_nc.begin(), range_nc.Num()));
+
+          const StridedViewBF C2 = args.env.C_tiles.C(
+              Extents2D(range_mc.Num(), range_nc.Num()), worker);
+
+          if (B2 != nullptr) {
+            MMKernel::B3A2C0(A, *B2, range_mc, range_kc, range_nc, args,
+                             MMSetC(), C2);
           }

           if constexpr (IsBF16<TC>()) {
-            if (args.options.fused) {
-              StridedViewBF C2(nullptr, 0, 0);
-              args.options.fused(C, range_M, range_nc, C2, worker);
-            }
+            args.options.MaybeCallFunc(C, range_mc, range_nc, C2, worker);
           }
         });
   }
@@ -948,7 +937,8 @@ class MMLoops {
   template <typename TB, typename TC, class Parallel>
   static HWY_INLINE void Loop(MMOrderNT_K, Parallel parallel,
                               const StridedViewBF A, const MatPtrT<TB>& B,
-                              RowPtrs<TC> C, const MMArgs& args) {
+                              const MatPtrT<TB>* B2, RowPtrs<TC> C,
+                              const MMArgs& args) {
     static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_K");
     HWY_DASSERT(args.ranges_mc.NumTasks() == 1);
     const IndexRange& range_mc = args.ranges_mc.Range(0);

@@ -959,14 +949,21 @@ class MMLoops {
         [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
           MMZone mm_zone;
           mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune);
-          MMKernel::ForeachKC(A, B, range_mc, args.ranges_kc,
-                              range_nc, args, C);
+          MMKernel::ForeachKC(
+              A, B, range_mc, args.ranges_kc, range_nc, args,
+              C.View(0, range_nc.begin(), range_nc.Num()));
+
+          const StridedViewBF C2 = args.env.C_tiles.C(
+              Extents2D(range_mc.Num(), range_nc.Num()), worker);
+
+          if (B2 != nullptr) {
+            MMKernel::ForeachKC(A, *B2, range_mc, args.ranges_kc,
+                                range_nc, args, C2);
+          }

           if constexpr (IsBF16<TC>()) {
-            if (args.options.fused) {
-              StridedViewBF C2(nullptr, 0, 0);
-              args.options.fused(C, range_mc, range_nc, C2, worker);
-            }
+            args.options.MaybeCallFunc(C, range_mc, range_nc, C2,
+                                       worker);
           }
         });
   }
@@ -976,10 +973,11 @@ class MMLoops {
   template <typename TB, typename TC, class Parallel>
   static HWY_INLINE void Loop(MMOrderNT_MT, Parallel parallel,
                               const StridedViewBF A, const MatPtrT<TB>& B,
-                              RowPtrs<TC> C, const MMArgs& args) {
+                              const MatPtrT<TB>* B2, RowPtrs<TC> C,
+                              const MMArgs& args) {
     static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_MT");
     HWY_DASSERT(args.ranges_kc.NumTasks() == 1);
-    const IndexRange& range_K = args.ranges_kc.Range(0);
+    const IndexRange& range_kc = args.ranges_kc.Range(0);

     parallel.ForRangesMC_NC(
         args.env.ctx, args.ranges_mc, args.ranges_nc, args.options.cluster_idx,

@@ -987,14 +985,19 @@ class MMLoops {
                           size_t worker) HWY_ATTR {
           MMZone mm_zone;
           mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune);
-          MMKernel::B3A2C0(A, B, range_mc, range_K, range_nc, args, MMSetC(),
-                           C);
-
-          if constexpr (IsBF16<TC>()) {
-            if (args.options.fused) {
-              StridedViewBF C2(nullptr, 0, 0);
-              args.options.fused(C, range_mc, range_nc, C2, worker);
-            }
+          MMKernel::B3A2C0(
+              A, B, range_mc, range_kc, range_nc, args, MMSetC(),
+              C.View(range_mc.begin(), range_nc.begin(), range_nc.Num()));
+
+          const StridedViewBF C2 = args.env.C_tiles.C(
+              Extents2D(range_mc.Num(), range_nc.Num()), worker);
+
+          if (B2 != nullptr) {
+            MMKernel::B3A2C0(A, *B2, range_mc, range_kc, range_nc, args,
+                             MMSetC(), C2);
           }
+          if constexpr (IsBF16<TC>()) {
+            args.options.MaybeCallFunc(C, range_mc, range_nc, C2, worker);
+          }
         });
   }
@@ -1004,7 +1007,8 @@ class MMLoops {
   template <typename TB, typename TC, class Parallel>
   static HWY_INLINE void Loop(MMOrderNT_MT_K, Parallel parallel,
                               const StridedViewBF A, const MatPtrT<TB>& B,
-                              RowPtrs<TC> C, const MMArgs& args) {
+                              const MatPtrT<TB>* B2, RowPtrs<TC> C,
+                              const MMArgs& args) {
     static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_MT_K");

     parallel.ForRangesMC_NC(

@@ -1013,14 +1017,20 @@ class MMLoops {
                           size_t worker) HWY_ATTR {
           MMZone mm_zone;
           mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune);
-          MMKernel::ForeachKC(A, B, range_mc, args.ranges_kc, range_nc, args,
-                              C);
+          MMKernel::ForeachKC(
+              A, B, range_mc, args.ranges_kc, range_nc, args,
+              C.View(range_mc.begin(), range_nc.begin(), range_nc.Num()));
+
+          const StridedViewBF C2 = args.env.C_tiles.C(
+              Extents2D(range_mc.Num(), range_nc.Num()), worker);
+
+          if (B2 != nullptr) {
+            MMKernel::ForeachKC(A, *B2, range_mc, args.ranges_kc, range_nc,
+                                args, C2);
+          }

           if constexpr (IsBF16<TC>()) {
-            if (args.options.fused) {
-              StridedViewBF C2(nullptr, 0, 0);
-              args.options.fused(C, range_mc, range_nc, C2, worker);
-            }
+            args.options.MaybeCallFunc(C, range_mc, range_nc, C2, worker);
           }
         });
   }
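Each loop order now delegates the "call the callback only if one was set" check to MMOptions::MaybeCallFunc instead of testing options.fused inline. A sketch of such a wrapper, assuming a std::function-style member; the real class may store and type the callback differently:

#include <cstddef>
#include <functional>
#include <utility>

struct TileRange { size_t begin = 0, num = 0; };

// Sketch of an options object whose callback is invoked only when set.
struct Options {
  using Func = std::function<void(float* c1, const float* c2, TileRange rows,
                                  TileRange cols, size_t worker)>;

  void SetFunc(Func f) { func_ = std::move(f); }

  // No-op when no callback was registered.
  void MaybeCallFunc(float* c1, const float* c2, TileRange rows, TileRange cols,
                     size_t worker) const {
    if (func_) func_(c1, c2, rows, cols, worker);
  }

 private:
  Func func_;
};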
@@ -1060,20 +1070,23 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
   const size_t M = A.Rows();
   const size_t K = A.Cols();
   const size_t N = B.Rows();
+  const size_t num_B = 1;

   const CacheInfo& cache = env.ctx.cache_info;
-  MMPerKey& per_key = MMImpl::FindOrAddPerKey(M, K, N, cache.VectorBytes(),
-                                              env.per_cluster[cluster_idx]);
+  MMPerKey& per_key = MMImpl::FindOrAddPerKey(
+      M, K, N, num_B, cache.VectorBytes(), env.per_cluster[cluster_idx]);

   // (Also auto-tunes, hence outside the timed section to prevent interference.)
   const StridedViewBF A_view =
       MMDecompress::MaybeDecompressA(A, per_key.autotune_par_a, env, options);

+  MatPtrT<TB>* B2 = nullptr;  // required for type matching
+
   MMAutoTune<MMConfig>& tuner = per_key.autotune;
   if (HWY_LIKELY(tuner.Best())) {
-    const MMArgs args(env, M, K, N, static_cast<double>(A.Scale()) * B.Scale(),
-                      add, options, tuner, *tuner.Best());
-    MMLoops::Dispatch(A_view, B, C_rows, args);
+    const MMArgs args(env, M, K, N, A.Scale(), add, options, tuner,
+                      *tuner.Best());
+    MMLoops::Dispatch(A_view, B, B2, C_rows, args);
     return &per_key;
   }
@@ -1082,20 +1095,83 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
     // Ensure matrix dimensions match each other (off the hot path).
     HWY_ASSERT(K == B.Cols());
     HWY_ASSERT(M <= kMaxBatchSize);
-    HWY_ASSERT(K <= MMStorage::kMaxK);
+    HWY_ASSERT(K <= MMEntireA::kMaxK);
     HWY_ASSERT(N % kNR == 0);
     MMImpl::EnsureAligned(A, cache.VectorBytes());
     tuner.SetCandidates(
-        MMCandidates(cache, M, K, N, sizeof(TC), env.print_config));
+        MMCandidates(cache, M, K, N, num_B, sizeof(TC), env.print_config));
   }

   const MMConfig& cfg = tuner.NextConfig();
-  const MMArgs args(env, M, K, N, static_cast<double>(A.Scale()) * B.Scale(),
-                    add, options, tuner, cfg);
+  const MMArgs args(env, M, K, N, A.Scale(), add, options, tuner, cfg);

   const uint64_t t0 = hwy::timer::Start();
-  MMLoops::Dispatch(A_view, B, C_rows, args);
-  MMImpl::NotifyAutotuneResult(env, M, K, N, t0, tuner, cfg);
+  MMLoops::Dispatch(A_view, B, B2, C_rows, args);
+  MMImpl::NotifyAutotuneResult(env, M, K, N, num_B, t0, tuner, cfg);
+
+  return &per_key;
+}
+
+// Performs A*B1 and A*B2 in parallel. This is useful for gated FFNs.
+// Differences vs MatMul: The second result matrix is not materialized, it is
+// only passed to the `options.func` callback. There is no `add` argument
+// because it is not required for this use case. There is no default `options`
+// argument because `options.func` must be set by the caller.
+template <typename TB>
+HWY_NOINLINE MMPerKey* TwoMatMul(const MatPtrT<BF16>& A, const MatPtrT<TB>& B1,
+                                 const MatPtrT<TB>& B2, MatMulEnv& env,
+                                 MatPtrT<BF16>& C, MMOptions options) {
+  static const auto zone = env.ctx.profiler.AddZone("MM.TwoMatMul");
+  const size_t cluster_idx = options.cluster_idx;
+  HWY_DASSERT(cluster_idx < env.row_ptrs.size());
+  PROFILER_ZONE3(env.ctx.profiler, env.ctx.Worker(cluster_idx), zone);
+
+  HWY_DASSERT(options.func != nullptr);  // no other way to get access to C2.
+
+  RowPtrs<BF16> C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[cluster_idx]);
+
+  const size_t M = A.Rows();
+  const size_t K = A.Cols();
+  const size_t N = B1.Rows();
+  const size_t num_B = 2;
+
+  const CacheInfo& cache = env.ctx.cache_info;
+  MMPerKey& per_key = MMImpl::FindOrAddPerKey(
+      M, K, N, num_B, cache.VectorBytes(), env.per_cluster[cluster_idx]);
+
+  // (Also auto-tunes, hence outside the timed section to prevent interference.)
+  const StridedViewBF A_view(A, 0, 0, A.Cols());
+
+  MMAutoTune<MMConfig>& tuner = per_key.autotune;
+  if (HWY_LIKELY(tuner.Best())) {
+    // Only A scale - B1/B2 may differ, and are passed separately.
+    const MMArgs args(env, M, K, N, A.Scale(),
+                      /*add=*/nullptr, options, tuner, *tuner.Best());
+    MMLoops::Dispatch(A_view, B1, &B2, C_rows, args);
+    return &per_key;
+  }
+
+  // Autotuning, first call: enumerate all feasible configs.
+  if (HWY_UNLIKELY(!tuner.HasCandidates())) {
+    // Ensure matrix dimensions match each other (off the hot path).
+    HWY_ASSERT(K == B1.Cols());
+    HWY_ASSERT(K == B2.Cols());
+    HWY_ASSERT(M <= kMaxBatchSize);
+    HWY_ASSERT(K <= MMEntireA::kMaxK);
+    HWY_ASSERT(N % kNR == 0);
+    MMImpl::EnsureAligned(A, cache.VectorBytes());
+    tuner.SetCandidates(
+        MMCandidates(cache, M, K, N, num_B, sizeof(BF16), env.print_config));
+  }
+
+  const MMConfig& cfg = tuner.NextConfig();
+  // Only A scale - B1/B2 may differ, and are passed separately.
+  const MMArgs args(env, M, K, N, A.Scale(), /*add=*/nullptr, options, tuner,
+                    cfg);
+
+  const uint64_t t0 = hwy::timer::Start();
+  MMLoops::Dispatch(A_view, B1, &B2, C_rows, args);
+  MMImpl::NotifyAutotuneResult(env, M, K, N, num_B, t0, tuner, cfg);

   return &per_key;
 }
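Both MatMul and the new TwoMatMul follow the same autotuning protocol: look up a per-shape key, which now includes num_B, then either dispatch with the stored best config or time the next candidate and record the result. A minimal self-contained sketch of that protocol; Config, the key layout, and the single-measurement timing are simplifications of the real MMAutoTune machinery:

#include <chrono>
#include <map>
#include <tuple>
#include <utility>
#include <vector>

struct Config { int mc = 0, nc = 0; };  // stand-in for MMConfig

struct PerKey {
  std::vector<Config> candidates;  // must be non-empty once set
  size_t next = 0;
  double best_seconds = 1e30;
  Config best{};
  bool done = false;
};

using Key = std::tuple<size_t, size_t, size_t, size_t>;  // M, K, N, num_B

template <class Kernel>
void AutotunedCall(std::map<Key, PerKey>& cache, Key key,
                   std::vector<Config> candidates, const Kernel& kernel) {
  PerKey& pk = cache[key];
  if (pk.candidates.empty()) pk.candidates = std::move(candidates);
  if (pk.done) {  // fast path: already tuned for this shape
    kernel(pk.best);
    return;
  }
  const Config cfg = pk.candidates[pk.next];
  const auto t0 = std::chrono::steady_clock::now();
  kernel(cfg);
  const double s =
      std::chrono::duration<double>(std::chrono::steady_clock::now() - t0)
          .count();
  if (s < pk.best_seconds) { pk.best_seconds = s; pk.best = cfg; }
  if (++pk.next == pk.candidates.size()) pk.done = true;
}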
@@ -63,11 +63,12 @@ size_t PrevDivisor(const size_t begin, const size_t end, const size_t dim,
 class GenerateCandidates {
  public:
   GenerateCandidates(const CacheInfo& cache, size_t M, size_t K, size_t N,
-                     size_t sizeof_TC, bool print_config)
+                     size_t num_B, size_t sizeof_TC, bool print_config)
       : cache_(cache),
         M_(M),
         K_(K),
         N_(N),
+        num_B_(num_B),
         sizeof_TC_(sizeof_TC),
         // These influence kc/nc, but are also stored in `MMConfig` for
         // `RangesOf*`. Must be a vector multiple. The previous/next cache line

@@ -150,7 +151,7 @@ class GenerateCandidates {
     }
   }

-  // The number of A and B columns to read between updating `partial`.
+  // The number of A and B columns to read between updating `C`.
   SizeVec KC(size_t mr, MMOrder order) const {
     // `LoopKC` handles up to `mr` rows of A.
     const size_t rows_a = HWY_MIN(M_, mr);

@@ -164,7 +165,7 @@ class GenerateCandidates {
     // TB=NUQ due to less amortization of the table loads. Due to the low L1
     // latency, the packing is still effectively fused into `LoopKC`. It may
     // be better to round up and accept a few L2 accesses in exchange for
-    // fewer loops over K, and thus fewer writes to `partial`. Hence we do not
+    // fewer loops over K, and thus fewer writes to `C`. Hence we do not
     // subtract the output and buf, and allow using more than the actual L1
     // size. This results in an overestimate, and the loop below will propose
     // the next few smaller values for the autotuner to evaluate.

@@ -179,7 +180,7 @@ class GenerateCandidates {

     // Avoid proposing kc > K.
     if (K_ > kc_multiple_) {
-      // Generally it is best to use the full `kc` (fewer writes to `partial`),
+      // Generally it is best to use the full `kc` (fewer writes to `C`),
       // but a bit less can be better if it evenly divides `K`, or enables an
       // `mc` that evenly divides `M`. Try several smaller values.
@@ -196,7 +197,7 @@ class GenerateCandidates {
     }

     if (print_config_ && all_kc.size() > 1) {
-      fprintf(stderr, "KC: ");
+      fprintf(stderr, "num_B %zu: KC: ", num_B_);
       for (size_t kc : all_kc) {
         fprintf(stderr, "%zu ", kc);
       }

@@ -214,18 +215,18 @@ class GenerateCandidates {

     // Choose the largest feasible `mc_max` (A/C rows) to maximize reuse of the
     // packed B. We want `mc * kc` elements of A to fit in L2, alongside
-    // `bytes_b` plus `mc` cache lines because resident-A updates `mc` rows of
-    // partial.
+    // `bytes_b` plus `mc` cache lines because resident-A updates `mc` C rows.
     const size_t bytes_per_mc = kc * sizeof(BF16) + cache_.LineBytes();
     size_t mc_max = hwy::DivCeil(cache_.L2Bytes() - bytes_b, bytes_per_mc);
-    mc_max = HWY_MIN(mc_max, kMaxBatchSize);
+    mc_max = HWY_MIN(mc_max, HWY_MIN(kMaxBatchSize, kMaxMC));
     HWY_DASSERT(mc_max != 0);
     mc_max = HWY_MIN(mc_max, M_);
     mc_max = hwy::RoundDownTo(mc_max, mr);

     SizeVec all_mc(1, mc_max);
-    // Larger MC is better for non-blocks, otherwise we want more small options.
-    const size_t reps = !IsBlock(order) ? 2 : 3;
+    // Larger MC is better for non-blocks, otherwise we want more small options,
+    // especially for two B.
+    const size_t reps = !IsBlock(order) ? 2 : (2 + num_B_);

     size_t prev = mc_max;
     for (size_t rep = 0; rep < reps; ++rep) {

@@ -240,7 +241,7 @@ class GenerateCandidates {
     }

     if (print_config_ && all_mc.size() > 1) {
-      fprintf(stderr, "MC: ");
+      fprintf(stderr, "num_B %zu: MC: ", num_B_);
       for (size_t mc : all_mc) {
         fprintf(stderr, "%zu ", mc);
       }
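The MC candidates are derived from an L2 budget: each candidate row of the A panel costs kc BF16 elements plus one cache line for the C row it updates, and whatever L2 remains after the packed B panel bounds mc, now additionally clamped to kMaxMC so the per-worker C tile fits preallocated storage. A small sketch of that calculation with illustrative constants (assumes bytes_b is smaller than the L2 size):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>

size_t MaxMC(size_t l2_bytes, size_t line_bytes, size_t kc, size_t bytes_b,
             size_t M, size_t mr, size_t max_mc, size_t max_batch) {
  // One BF16 row of the A panel plus one cache line for its C row.
  const size_t bytes_per_mc = kc * sizeof(uint16_t) + line_bytes;
  size_t mc_max = (l2_bytes - bytes_b + bytes_per_mc - 1) / bytes_per_mc;
  mc_max = std::min(mc_max, std::min(max_batch, max_mc));
  mc_max = std::min(mc_max, M);
  return (mc_max / mr) * mr;  // round down to a multiple of mr
}

int main() {
  // Illustrative: 1 MiB L2, 64-byte lines, kc=2048, 256 KiB packed B.
  std::printf("mc_max = %zu\n",
              MaxMC(1u << 20, 64, 2048, 256u << 10, /*M=*/512, /*mr=*/4,
                    /*max_mc=*/512, /*max_batch=*/256));
  return 0;
}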
@@ -252,14 +253,15 @@ class GenerateCandidates {

   // The number of (possibly L3 resident) B rows per `NT_MT` task.
   SizeVec NC(size_t mr, size_t mc, size_t kc, MMOrder order) const {
-    size_t nc_max = N_;
+    size_t nc_max = kMaxNC;
     // Only if there will be reuse of B: choose the largest `nc_max` (C cols)
-    // such that `nc x kc` of B and `mc x nc` of `partial` or `C` fit in L3.
-    // Otherwise, leave it unbounded.
+    // such that `nc x kc` of B and `mc x nc` of `C` fit in L3. Otherwise,
+    // leave it unbounded.
     if (M_ > mr) {
       const size_t bytes_per_nc = (kc * sizeof(BF16) + mc * sizeof_TC_);
-      nc_max = HWY_MIN(hwy::DivCeil(cache_.L3Bytes(), bytes_per_nc), N_);
+      nc_max = HWY_MIN(hwy::DivCeil(cache_.L3Bytes(), bytes_per_nc), kMaxNC);
     }
+    nc_max = HWY_MIN(nc_max, N_);
     HWY_DASSERT(nc_max != 0);
     nc_max = RoundDownWithFloor(nc_max, nc_multiple_);

@@ -278,7 +280,7 @@ class GenerateCandidates {
     if (N_ > nc_multiple_) {
       // Large L3, but its behavior and characteristics varies across platforms,
       // hence autotune a wider range of nc than the other dimensions.
-      size_t reps = 10;
+      size_t reps = 9 + num_B_;
       // For small M, we can afford larger NC, hence allow fewer small options.
       if (M_ <= 2 * mr) reps -= 1;

@@ -301,7 +303,7 @@ class GenerateCandidates {
     }

     if (print_config_ && all_nc.size() > 1) {
-      fprintf(stderr, "NC: ");
+      fprintf(stderr, "num_B %zu: NC: ", num_B_);
       for (size_t nc : all_nc) {
         fprintf(stderr, "%zu ", nc);
       }

@@ -329,6 +331,7 @@ class GenerateCandidates {
   const size_t M_;
   const size_t K_;
   const size_t N_;
+  const size_t num_B_;
   const size_t sizeof_TC_;

   const size_t kc_multiple_;

@@ -341,12 +344,13 @@ class GenerateCandidates {

 // Facade to avoid exposing `GenerateCandidates` in the header.
 std::vector<MMConfig> MMCandidates(const CacheInfo& cache, size_t M, size_t K,
-                                   size_t N, size_t sizeof_TC,
+                                   size_t N, size_t num_B, size_t sizeof_TC,
                                    bool print_config) {
-  return GenerateCandidates(cache, M, K, N, sizeof_TC, print_config)();
+  return GenerateCandidates(cache, M, K, N, num_B, sizeof_TC, print_config)();
 }

-MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx), storage(ctx.allocator) {
+MatMulEnv::MatMulEnv(ThreadingContext& ctx)
+    : ctx(ctx), A_BF(ctx.allocator), C_tiles(ctx) {
   const size_t num_clusters = ctx.pools.AllClusters(/*pkg_idx=*/0).NumWorkers();
   per_cluster.resize(num_clusters);
   for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) {
ops/matmul.h
@@ -21,7 +21,6 @@
 #include <stddef.h>
 #include <stdint.h>

-#include <functional>
 #include <vector>

 // IWYU pragma: begin_exports

@@ -54,7 +53,9 @@ HWY_INLINE_VAR constexpr size_t kNR = 4;
 // or less on ISAs with fewer registers, or for the last few rows of A.
 HWY_INLINE_VAR constexpr size_t kMaxMR = 4;

-HWY_INLINE_VAR constexpr size_t kMaxNC = 16384;  // TODO: shrink?
+// For `MMTilesC`.
+HWY_INLINE_VAR constexpr size_t kMaxMC = 512;
+HWY_INLINE_VAR constexpr size_t kMaxNC = 16384;

 // Upper bound for per-worker B storage on the stack. Chosen such that one row
 // of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`.
@@ -108,9 +109,9 @@ struct MMParallelWithinCluster {
     hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
     const size_t base = ctx.Worker(cluster_idx);

-    const IndexRangePartition worker_ranges = StaticPartition(
+    const IndexRangePartition ranges_n = StaticPartition(
         range_n, cluster.NumWorkers() * inner_tasks, n_multiple);
-    ParallelizeOneRange(worker_ranges, cluster,
+    ParallelizeOneRange(ranges_n, cluster,
                         [&](const IndexRange& worker_range, size_t worker) {
                           func(worker_range, base + worker);
                         });

@@ -169,20 +170,20 @@ struct MMParallelHierarchical {
     if (num_clusters == 1) {
       const size_t cluster_idx = 0;
       hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
-      const IndexRangePartition worker_ranges = StaticPartition(
+      const IndexRangePartition ranges_n = StaticPartition(
           range_n, cluster.NumWorkers() * inner_tasks, n_multiple);
       return ParallelizeOneRange(
-          worker_ranges, cluster,
+          ranges_n, cluster,
          [&](const IndexRange& worker_range, size_t worker) {
             func(worker_range, worker);
           });
     }

     // Assign each cluster a sub-range of `range_n` (typically hundreds).
-    const IndexRangePartition n_ranges =
+    const IndexRangePartition ranges_n =
         StaticPartition(range_n, num_clusters, n_multiple);
     ParallelizeOneRange(
-        n_ranges, all_clusters,
+        ranges_n, all_clusters,
         [&](const IndexRange& n_range, const size_t cluster_idx) {
           hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
           const size_t cluster_base = ctx.Worker(cluster_idx);
@ -274,32 +275,51 @@ void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC);
|
||||||
// C is BF16/float.
|
// C is BF16/float.
|
||||||
void BindC(ThreadingContext& ctx, MatPtr& C);
|
void BindC(ThreadingContext& ctx, MatPtr& C);
|
||||||
|
|
||||||
// For A.
|
// Space for converting A=F32 to BF16 before the matmul. This is faster than
|
||||||
class MMStorage {
|
// on-the-fly when native BF16 is available: it only happens once, not per B
|
||||||
|
// tile row, and the cache footprint is smaller.
|
||||||
|
class MMEntireA {
|
||||||
public:
|
public:
|
||||||
// Compile-time bounds on matrix columns to enable pre-allocating storage
|
// Compile-time bounds on matrix columns to enable pre-allocating storage
|
||||||
// and reusing it across `MatMul` calls. Sufficient for Gemma 2 27B.
|
// and reusing it across `MatMul` calls. Sufficient for Gemma 2 27B.
|
||||||
static constexpr size_t kMaxK = 36 * 1024;
|
static constexpr size_t kMaxK = 36 * 1024;
|
||||||
|
|
||||||
MMStorage(const Allocator& allocator)
|
explicit MMEntireA(const Allocator& allocator)
|
||||||
// 288 MiB. Must be padded, see `DoDecompressA`.
|
// 288 MiB. Must be padded, see `DoDecompressA`.
|
||||||
: A_("A_bf", Extents2D(kMaxBatchSize, kMaxK), allocator,
|
: A_("A_bf", Extents2D(kMaxBatchSize, kMaxK), allocator,
|
||||||
MatPadding::kOdd) {}
|
MatPadding::kOdd) {}
|
||||||
|
|
||||||
// Returns matrix view. Converting A=F32 to BF16 up-front is faster than
|
|
||||||
// on-the-fly when native BF16 is available: it only happens once, not per B
|
|
||||||
// tile row, and the cache footprint is smaller.
|
|
||||||
StridedViewBF A(const Extents2D& extents) const {
|
StridedViewBF A(const Extents2D& extents) const {
|
||||||
HWY_DASSERT(extents.rows <= kMaxBatchSize);
|
HWY_DASSERT(extents.rows <= kMaxBatchSize);
|
||||||
HWY_DASSERT(extents.cols <= kMaxK);
|
return StridedViewBF(A_, 0, 0, extents.cols);
|
||||||
return StridedViewBF(const_cast<BF16*>(A_.Row(0)), extents.cols,
|
|
||||||
A_.Stride());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
MatStorageT<BF16> A_;
|
MatStorageT<BF16> A_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// One tile of C per *worker* (required for `kNT_MT*`).
|
||||||
|
class MMTilesC {
|
||||||
|
public:
|
||||||
|
explicit MMTilesC(const ThreadingContext& ctx) {
|
||||||
|
const size_t max_workers = ctx.pools.MaxWorkers();
|
||||||
|
C_.reserve(max_workers);
|
||||||
|
for (size_t worker = 0; worker < max_workers; ++worker) {
|
||||||
|
C_.push_back(MatStorageT<BF16>("Ctile", Extents2D(kMaxBatchSize, kMaxNC),
|
||||||
|
ctx.allocator, MatPadding::kOdd));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
StridedViewBF C(const Extents2D& extents, size_t worker) const {
|
||||||
|
HWY_DASSERT(extents.rows <= kMaxBatchSize);
|
||||||
|
HWY_DASSERT(worker < C_.size());
|
||||||
|
return StridedViewBF(C_[worker], 0, 0, extents.cols);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<MatStorageT<BF16>> C_;
|
||||||
|
};
|
||||||
|
|
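The per-worker C tiles added above follow a common pattern: each worker gets its own pre-sized scratch tile, so concurrent epilogue writes never land in a shared buffer. Below is a minimal standalone sketch of that pattern, with plain std::vector<float> standing in for MatStorageT<BF16>; the names are illustrative, not gemma.cpp APIs.

#include <cstddef>
#include <cstdio>
#include <vector>

// One scratch tile per worker, sized for the largest extents expected.
// Each worker writes only through its own index, so no locking is needed
// and no two workers write into the same buffer.
class PerWorkerTiles {
 public:
  PerWorkerTiles(size_t max_workers, size_t max_rows, size_t max_cols)
      : max_cols_(max_cols),
        tiles_(max_workers, std::vector<float>(max_rows * max_cols)) {}

  float* Tile(size_t worker) { return tiles_[worker].data(); }
  size_t Stride() const { return max_cols_; }  // rows are max_cols_ apart

 private:
  size_t max_cols_;
  std::vector<std::vector<float>> tiles_;
};

int main() {
  PerWorkerTiles tiles(/*max_workers=*/8, /*max_rows=*/4, /*max_cols=*/16);
  for (size_t worker = 0; worker < 8; ++worker) {
    tiles.Tile(worker)[0] = static_cast<float>(worker);  // no sharing
  }
  printf("%.1f\n", tiles.Tile(3)[0]);  // prints 3.0
}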
||||||
//------------------------------------------------------------------------------
|
//------------------------------------------------------------------------------
|
||||||
// Autotuning
|
// Autotuning
|
||||||
|
|
||||||
|
|
@ -471,7 +491,7 @@ static_assert(sizeof(MMConfig) == 32); // for faster indexing
|
||||||
#pragma pack(pop)
|
#pragma pack(pop)
|
||||||
|
|
||||||
std::vector<MMConfig> MMCandidates(const CacheInfo& cache, size_t M, size_t K,
|
std::vector<MMConfig> MMCandidates(const CacheInfo& cache, size_t M, size_t K,
|
||||||
size_t N, size_t sizeof_TC,
|
size_t N, size_t num_B, size_t sizeof_TC,
|
||||||
bool print_config);
|
bool print_config);
|
||||||
|
|
||||||
// State machine for choosing the best `TConfig`, which is `MMConfig` for the
|
// State machine for choosing the best `TConfig`, which is `MMConfig` for the
|
||||||
|
|
@ -595,12 +615,14 @@ class MMKeys {
|
||||||
static constexpr Key kPadding = 0;
|
static constexpr Key kPadding = 0;
|
||||||
|
|
||||||
// Compresses the dimensions into a single Key for faster comparison.
|
// Compresses the dimensions into a single Key for faster comparison.
|
||||||
static Key KeyFromDims(size_t M, size_t K, size_t N) {
|
static Key KeyFromDims(size_t M, size_t K, size_t N, size_t num_B) {
|
||||||
HWY_DASSERT(M < (Key{1} << 16)); // batch sizes are smaller
|
HWY_DASSERT(M < (Key{1} << 16)); // batch sizes are smaller
|
||||||
HWY_DASSERT(K < (Key{1} << 24));
|
HWY_DASSERT(K < (Key{1} << 20));
|
||||||
HWY_DASSERT(N < (Key{1} << 24));
|
HWY_DASSERT(N < (Key{1} << 20));
|
||||||
|
HWY_DASSERT(num_B == 1 || num_B == 2);
|
||||||
const Key key = static_cast<Key>(BucketM(M)) | (static_cast<Key>(K) << 16) |
|
const Key key = static_cast<Key>(BucketM(M)) | (static_cast<Key>(K) << 16) |
|
||||||
(static_cast<Key>(N) << 40);
|
(static_cast<Key>(N) << 40) |
|
||||||
|
(static_cast<Key>(num_B) << 60);
|
||||||
HWY_DASSERT(key != kPadding);
|
HWY_DASSERT(key != kPadding);
|
||||||
return key;
|
return key;
|
||||||
}
|
}
|
||||||
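To make the key layout above concrete, here is a standalone sketch that mirrors the packing (bucketed M in bits 0..15, K shifted by 16, N by 40, num_B by 60). BucketM is reduced to an identity mapping, so only the bit layout carries over; none of this is the library's API.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdio>

using Key = uint64_t;

// Stand-in for the real BucketM(), which buckets the batch size; an identity
// mapping is enough to show the layout.
static uint16_t BucketM(size_t M) { return static_cast<uint16_t>(M); }

static Key KeyFromDims(size_t M, size_t K, size_t N, size_t num_B) {
  assert(M < (Key{1} << 16));  // batch sizes are smaller
  assert(K < (Key{1} << 20));
  assert(N < (Key{1} << 20));
  assert(num_B == 1 || num_B == 2);
  // Bits 0..15: bucketed M; 16..35: K; 40..59: N; 60..: num_B.
  return static_cast<Key>(BucketM(M)) | (static_cast<Key>(K) << 16) |
         (static_cast<Key>(N) << 40) | (static_cast<Key>(num_B) << 60);
}

int main() {
  // The same shape with one vs. two B matrices maps to distinct keys,
  // so their autotune results are cached separately.
  const Key k1 = KeyFromDims(4, 3072, 24576, 1);
  const Key k2 = KeyFromDims(4, 3072, 24576, 2);
  printf("%d\n", k1 != k2);  // prints 1
}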
|
|
@ -643,10 +665,6 @@ class MMKeys {
|
||||||
|
|
||||||
// Per-MatMul-shape state.
|
// Per-MatMul-shape state.
|
||||||
struct MMPerKey {
|
struct MMPerKey {
|
||||||
// Only profile if enabled and the main autotuner finished. `autotune_par_a`
|
|
||||||
// might not be active if inputs are all BF16.
|
|
||||||
bool WantProfile() const { return PROFILER_ENABLED != 0 && autotune.Best(); }
|
|
||||||
|
|
||||||
MMAutoTune<MMConfig> autotune;
|
MMAutoTune<MMConfig> autotune;
|
||||||
MMAutoTune<MMParA> autotune_par_a;
|
MMAutoTune<MMParA> autotune_par_a;
|
||||||
};
|
};
|
||||||
|
|
@ -666,12 +684,15 @@ struct MatMulEnv {
|
||||||
// Whether to print the best config immediately after autotuning finished.
|
// Whether to print the best config immediately after autotuning finished.
|
||||||
bool print_best = false;
|
bool print_best = false;
|
||||||
|
|
||||||
MMStorage storage;
|
MMEntireA A_BF;
|
||||||
|
MMTilesC C_tiles;
|
||||||
|
|
||||||
struct PerCluster {
|
struct PerCluster {
|
||||||
MMKeys keys;
|
MMKeys keys;
|
||||||
std::vector<MMPerKey> per_key;
|
std::vector<MMPerKey> per_key;
|
||||||
HWY_MAYBE_UNUSED uint8_t padding[HWY_ALIGNMENT]; // prevent false sharing
|
// Prevents false sharing.
|
||||||
|
HWY_MAYBE_UNUSED uint8_t
|
||||||
|
padding[HWY_ALIGNMENT - sizeof(MMKeys) - sizeof(per_key)];
|
||||||
};
|
};
|
||||||
std::vector<PerCluster> per_cluster;
|
std::vector<PerCluster> per_cluster;
|
||||||
|
|
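The padding member above rounds each PerCluster entry up to HWY_ALIGNMENT so adjacent entries never share a cache line. A generic sketch of the same idea, assuming a 64-byte line where the real code derives the pad from HWY_ALIGNMENT and the sizes of the other members:

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Pad each per-thread slot to a full line so neighboring slots never share
// one; 64 bytes is assumed here for illustration.
constexpr size_t kLineBytes = 64;

struct Slot {
  uint64_t value = 0;
  uint8_t padding[kLineBytes - sizeof(uint64_t)];
};
static_assert(sizeof(Slot) == kLineBytes, "one slot per line");

int main() {
  Slot per_thread[4];  // each element occupies its own line
  per_thread[0].value = 1;
  per_thread[1].value = 2;
  printf("%zu\n", sizeof(per_thread));  // prints 256
}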
||||||
|
|
@ -687,31 +708,57 @@ struct MatMulEnv {
|
||||||
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
|
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Called with the entire C matrix, the sub-ranges of M (rows) and N (cols)
|
// Called via `CallClosure`, which consumes the first (opaque) argument. User
|
||||||
// that this thread has just filled, a view into a second tile (only for the
|
// functions are called with the entire C matrix, the sub-ranges of M (rows)
|
||||||
// upcoming `GatedMatmul`), and the worker thread index (see `ParallelFor`).
|
// and N (cols) that this thread has just filled, a view into a second tile
|
||||||
using MMFused = std::function<void(RowPtrsBF, IndexRange, IndexRange,
|
// (only for `TwoMatmul`), and the worker thread index (see `ParallelFor`).
|
||||||
StridedViewBF, size_t)>;
|
typedef void (*MMFunc)(const void* opaque, RowPtrsBF, IndexRange, IndexRange,
|
||||||
|
StridedViewBF, size_t);
|
||||||
|
|
||||||
|
class MMOptions {
|
||||||
|
// Same technique as in `hwy::ThreadPool` and C++23 `std::function_ref`:
|
||||||
|
// type-erasure without allocation.
|
||||||
|
template <class Closure>
|
||||||
|
static void CallClosure(const void* opaque, RowPtrsBF C1, IndexRange range_r,
|
||||||
|
IndexRange range_c, StridedViewBF C2, size_t worker) {
|
||||||
|
(*reinterpret_cast<const Closure*>(opaque))(C1, range_r, range_c, C2,
|
||||||
|
worker);
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
// `closure` must remain alive until the end of (Two)MatMul.
|
||||||
|
template <class Closure>
|
||||||
|
void SetFunc(const Closure& closure) {
|
||||||
|
func = static_cast<MMFunc>(&CallClosure<Closure>);
|
||||||
|
opaque = &closure;
|
||||||
|
}
|
||||||
|
|
||||||
|
void MaybeCallFunc(RowPtrsBF C1, IndexRange range_r, IndexRange range_c,
|
||||||
|
StridedViewBF C2, size_t worker) const {
|
||||||
|
if (func != nullptr) {
|
||||||
|
func(opaque, C1, range_r, range_c, C2, worker);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
MMFunc func = nullptr; // called if non-null and `TC` is BF16.
|
||||||
|
const void* opaque = nullptr;
|
||||||
|
|
||||||
struct MMOptions {
|
|
||||||
uint32_t cluster_idx = 0; // for `parallelism == kWithinCluster`.
|
uint32_t cluster_idx = 0; // for `parallelism == kWithinCluster`.
|
||||||
ParallelismStrategy parallelism = ParallelismStrategy::kHierarchical;
|
ParallelismStrategy parallelism = ParallelismStrategy::kHierarchical;
|
||||||
|
|
||||||
MMFused fused; // called if non-null and `TC` is BF16.
|
|
||||||
};
|
};
|
||||||
|
|
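SetFunc/CallClosure above use the function-pointer-plus-opaque-pointer form of type erasure, which avoids the allocation a std::function member could incur. A stripped-down sketch of the pattern with a simplified callback signature (the real MMFunc passes RowPtrsBF, IndexRanges and a StridedViewBF); everything named below is illustrative.

#include <cstdio>

// Simplified callback signature; the real MMFunc also passes tile views and
// index ranges.
typedef void (*Func)(const void* opaque, int value);

class Options {
  // Type-erasure without allocation: the template restores the closure type.
  template <class Closure>
  static void CallClosure(const void* opaque, int value) {
    (*static_cast<const Closure*>(opaque))(value);
  }

 public:
  // `closure` must stay alive for as long as calls can happen.
  template <class Closure>
  void SetFunc(const Closure& closure) {
    func_ = &CallClosure<Closure>;
    opaque_ = &closure;
  }

  void MaybeCall(int value) const {
    if (func_ != nullptr) func_(opaque_, value);
  }

 private:
  Func func_ = nullptr;
  const void* opaque_ = nullptr;
};

int main() {
  int sum = 0;
  const auto add = [&sum](int value) { sum += value; };
  Options opts;
  opts.SetFunc(add);  // stores a pointer to `add`, no heap allocation
  opts.MaybeCall(3);
  opts.MaybeCall(4);
  printf("%d\n", sum);  // prints 7
}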
||||||
// Arguments to MatMul() that are independent of the A/B/C types. Reduces
|
// Arguments to MatMul() that are independent of the A/B/C types. Reduces
|
||||||
// register pressure compared to individual values/references. Also used for
|
// register pressure compared to individual values/references. Also used for
|
||||||
// passing through `DispatchOrder`.
|
// passing through `DispatchOrder`.
|
||||||
struct MMArgs {
|
struct MMArgs {
|
||||||
MMArgs(MatMulEnv& env, size_t M, size_t K, size_t N, double scale,
|
MMArgs(MatMulEnv& env, size_t M, size_t K, size_t N, float scale_A,
|
||||||
const float* HWY_RESTRICT add, MMOptions options,
|
const float* HWY_RESTRICT add, MMOptions options,
|
||||||
const MMAutoTune<MMConfig>& autotune, const MMConfig& config)
|
const MMAutoTune<MMConfig>& autotune, const MMConfig& config)
|
||||||
: env(env),
|
: env(env),
|
||||||
line_bytes(env.ctx.cache_info.LineBytes()),
|
line_bytes(env.ctx.cache_info.LineBytes()),
|
||||||
|
|
||||||
range_n(0, N),
|
range_n(0, N),
|
||||||
scale(scale),
|
scale_A(scale_A),
|
||||||
add(add),
|
add(add),
|
||||||
options(options),
|
options(options),
|
||||||
|
|
||||||
|
|
@ -728,7 +775,8 @@ struct MMArgs {
|
||||||
|
|
||||||
// MatMul arguments:
|
// MatMul arguments:
|
||||||
const IndexRange range_n; // entire N
|
const IndexRange range_n; // entire N
|
||||||
const double scale;
|
// There can be two B, so do not yet multiply together the A and B scales.
|
||||||
|
const float scale_A;
|
||||||
const float* HWY_RESTRICT add;
|
const float* HWY_RESTRICT add;
|
||||||
const MMOptions options;
|
const MMOptions options;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -53,6 +53,14 @@ namespace HWY_NAMESPACE {
|
||||||
// included from matmul_static_*.cc.
|
// included from matmul_static_*.cc.
|
||||||
GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DEFINE_ONE, GEMMA_MATMUL_TB) // NOLINT
|
GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DEFINE_ONE, GEMMA_MATMUL_TB) // NOLINT
|
||||||
|
|
||||||
|
HWY_MAYBE_UNUSED void TwoMatMulStatic(const MatPtrT<BF16>& A, // NOLINT
|
||||||
|
const MatPtrT<GEMMA_MATMUL_TB>& B1,
|
||||||
|
const MatPtrT<GEMMA_MATMUL_TB>& B2,
|
||||||
|
MatMulEnv& env, MatPtrT<BF16>& C,
|
||||||
|
MMOptions options) {
|
||||||
|
TwoMatMul(A, B1, B2, env, C, options);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace HWY_NAMESPACE
|
} // namespace HWY_NAMESPACE
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
HWY_AFTER_NAMESPACE();
|
HWY_AFTER_NAMESPACE();
|
||||||
|
|
|
||||||
|
|
@ -37,13 +37,19 @@
|
||||||
const float* HWY_RESTRICT add, MatMulEnv& env, \
|
const float* HWY_RESTRICT add, MatMulEnv& env, \
|
||||||
MatPtrT<TC>& C, MMOptions options);
|
MatPtrT<TC>& C, MMOptions options);
|
||||||
|
|
||||||
|
#define GEMMA_MATMUL_FOR_B(TB) \
|
||||||
|
GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, TB) \
|
||||||
|
void TwoMatMulStatic(const MatPtrT<BF16>& A, const MatPtrT<TB>& B1, \
|
||||||
|
const MatPtrT<TB>& B2, MatMulEnv& env, \
|
||||||
|
MatPtrT<BF16>& C, MMOptions options);
|
||||||
|
|
||||||
// Passed to HWY_VISIT_TARGETS; declares all overloads for all targets.
|
// Passed to HWY_VISIT_TARGETS; declares all overloads for all targets.
|
||||||
#define GEMMA_MATMUL_DECL(TARGET, NAMESPACE) \
|
#define GEMMA_MATMUL_DECL(TARGET, NAMESPACE) \
|
||||||
namespace NAMESPACE { \
|
namespace NAMESPACE { \
|
||||||
GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, BF16) \
|
GEMMA_MATMUL_FOR_B(BF16) \
|
||||||
GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, float) \
|
GEMMA_MATMUL_FOR_B(float) \
|
||||||
GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, NuqStream) \
|
GEMMA_MATMUL_FOR_B(NuqStream) \
|
||||||
GEMMA_MATMUL_FOREACH_AC(GEMMA_MATMUL_DECL_ONE, SfpStream) \
|
GEMMA_MATMUL_FOR_B(SfpStream) \
|
||||||
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
|
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
|
||||||
} // namespace NAMESPACE
|
} // namespace NAMESPACE
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,8 @@
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
||||||
|
#include <atomic>
|
||||||
|
|
||||||
#include "ops/matmul.h"
|
#include "ops/matmul.h"
|
||||||
#include "util/basics.h"
|
#include "util/basics.h"
|
||||||
#include "util/mat.h"
|
#include "util/mat.h"
|
||||||
|
|
@ -246,7 +248,9 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
|
||||||
MatStorageT<TC> C_slow("C_slow", C_extents, env.ctx.allocator,
|
MatStorageT<TC> C_slow("C_slow", C_extents, env.ctx.allocator,
|
||||||
MatPadding::kOdd);
|
MatPadding::kOdd);
|
||||||
MatStorageT<TC> C("C", C_extents, env.ctx.allocator, MatPadding::kOdd);
|
MatStorageT<TC> C("C", C_extents, env.ctx.allocator, MatPadding::kOdd);
|
||||||
|
MatStorageT<TC> C2("C", C_extents, env.ctx.allocator, MatPadding::kOdd);
|
||||||
C.AllocateAndAttachRowPtrs(env.row_ptrs);
|
C.AllocateAndAttachRowPtrs(env.row_ptrs);
|
||||||
|
C2.AllocateAndAttachRowPtrs(env.row_ptrs);
|
||||||
|
|
||||||
MatStorageT<float> add_storage =
|
MatStorageT<float> add_storage =
|
||||||
add ? GenerateMat<float>(Extents2D(1, cols_bc), env.ctx.allocator,
|
add ? GenerateMat<float>(Extents2D(1, cols_bc), env.ctx.allocator,
|
||||||
|
|
@ -262,7 +266,48 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
|
||||||
for (size_t rep = 0; rep < 16; ++rep) {
|
for (size_t rep = 0; rep < 16; ++rep) {
|
||||||
MMPerKey* per_key = MatMulStatic(A, BT, add_row, env, C, options);
|
MMPerKey* per_key = MatMulStatic(A, BT, add_row, env, C, options);
|
||||||
AssertClose(A, BT, C_slow, C, env, line);
|
AssertClose(A, BT, C_slow, C, env, line);
|
||||||
if (per_key->autotune.Best()) break;
|
// Check before TwoMatMulStatic(), which can invalidate per_key.
|
||||||
|
const bool autotune_done = !!per_key->autotune.Best();
|
||||||
|
|
||||||
|
// Ensure the tiled view returns the same result as C.
|
||||||
|
if constexpr (IsBF16<TA>() && IsBF16<TC>()) {
|
||||||
|
// The total view area should match the entire C matrix.
|
||||||
|
std::atomic<size_t> total_view_area = 0;
|
||||||
|
|
||||||
|
const auto fused = [&](RowPtrsBF C2_rows, IndexRange range_r,
|
||||||
|
IndexRange range_c, StridedViewBF C2_view,
|
||||||
|
size_t worker) {
|
||||||
|
total_view_area.fetch_add(range_r.Num() * range_c.Num());
|
||||||
|
HWY_ASSERT(range_c.Num() <= C2_view.Cols());
|
||||||
|
HWY_ASSERT(worker < env.ctx.pools.MaxWorkers());
|
||||||
|
for (size_t ir = 0; ir < range_r.Num(); ++ir) {
|
||||||
|
const size_t r = range_r.begin() + ir;
|
||||||
|
for (size_t ic = 0; ic < range_c.Num(); ++ic) {
|
||||||
|
const size_t c = range_c.begin() + ic;
|
||||||
|
const float expected =
|
||||||
|
hwy::ConvertScalarTo<float>(C2_rows.Row(r)[c]);
|
||||||
|
const float actual =
|
||||||
|
hwy::ConvertScalarTo<float>(C2_view.Row(ir)[ic]);
|
||||||
|
const float L1 = hwy::ScalarAbs(actual - expected);
|
||||||
|
if (L1 > 1E-6f) {
|
||||||
|
HWY_ABORT("%zu: ir %zu ic %zu L1 %f expected %f actual %f.",
|
||||||
|
worker, ir, ic, L1, expected, actual);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
options.SetFunc(fused);
|
||||||
|
TwoMatMulStatic(A, BT, BT, env, C2, options);
|
||||||
|
HWY_ASSERT_EQ(C.Extents().Area(), total_view_area.load());
|
||||||
|
options.func = nullptr; // reset for next call
|
||||||
|
|
||||||
|
// TwoMatMulStatic() does not support adding a bias vector.
|
||||||
|
if (!add) {
|
||||||
|
AssertClose(A, BT, C, C2, env, line);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (autotune_done) break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -69,6 +69,14 @@ MMPerKey* CallMatMul(const MatPtrT<TA>& A, const MatPtr& B,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static inline void CallTwoMatMul(const MatPtrT<BF16>& A, const MatPtr& B1,
|
||||||
|
const MatPtr& B2, MatMulEnv& env,
|
||||||
|
MatPtrT<BF16>& C, const MMOptions& options) {
|
||||||
|
return CallUpcastedSame(&B1, &B2, [&](const auto* B1_t, const auto* B2_t) {
|
||||||
|
return TwoMatMulStatic(A, *B1_t, *B2_t, env, C, options);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
HWY_INLINE double PackTokenAndProb(int32_t token, float prob) {
|
HWY_INLINE double PackTokenAndProb(int32_t token, float prob) {
|
||||||
// casting prob from float to double just makes some changes to the
|
// casting prob from float to double just makes some changes to the
|
||||||
// exponent bias and pads zeros in the mantissa.
|
// exponent bias and pads zeros in the mantissa.
|
||||||
|
|
|
||||||
13 util/mat.h
|
|
@ -40,10 +40,11 @@ class RowPtrs {
|
||||||
public:
|
public:
|
||||||
RowPtrs(uint8_t** row_ptrs) : row_ptrs_(row_ptrs), r0_(0), c0_(0) {}
|
RowPtrs(uint8_t** row_ptrs) : row_ptrs_(row_ptrs), r0_(0), c0_(0) {}
|
||||||
|
|
||||||
RowPtrs View(size_t r, size_t c) {
|
// Extra argument is for compatibility with `StridedView`.
|
||||||
|
RowPtrs View(size_t r, size_t c, size_t /*cols*/) {
|
||||||
RowPtrs<T> view(row_ptrs_);
|
RowPtrs<T> view(row_ptrs_);
|
||||||
view.r0_ = static_cast<uint32_t>(r);
|
view.r0_ = static_cast<uint32_t>(r0_ + r);
|
||||||
view.c0_ = static_cast<uint32_t>(c);
|
view.c0_ = static_cast<uint32_t>(c0_ + c);
|
||||||
return view;
|
return view;
|
||||||
}
|
}
|
||||||
|
|
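The View change above accumulates the caller's existing offsets (r0_ + r, c0_ + c), so a view of a view now addresses the expected sub-block rather than restarting from the matrix origin. A toy sketch of just that offset bookkeeping, separate from the real RowPtrs:

#include <cassert>
#include <cstdint>
#include <cstdio>

// Toy offset-only view: shows why View() must add to the existing r0_/c0_
// instead of overwriting them.
struct ToyView {
  uint32_t r0 = 0, c0 = 0;

  ToyView View(uint32_t r, uint32_t c) const {
    ToyView v = *this;
    v.r0 = r0 + r;  // accumulate the offsets
    v.c0 = c0 + c;
    return v;
  }
};

int main() {
  const ToyView whole;
  const ToyView inner = whole.View(2, 8).View(1, 4);
  assert(inner.r0 == 3 && inner.c0 == 12);  // offsets compose
  printf("r0=%u c0=%u\n", inner.r0, inner.c0);
}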
||||||
|
|
@ -531,7 +532,11 @@ class StridedView {
|
||||||
: row0_(row0),
|
: row0_(row0),
|
||||||
cols_(static_cast<uint32_t>(cols)),
|
cols_(static_cast<uint32_t>(cols)),
|
||||||
stride_(static_cast<uint32_t>(stride)) {
|
stride_(static_cast<uint32_t>(stride)) {
|
||||||
HWY_DASSERT(stride >= cols);
|
if constexpr (HWY_IS_DEBUG_BUILD) {
|
||||||
|
if (stride < cols) {
|
||||||
|
HWY_ABORT("stride %zu < cols %zu", stride, cols);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns 2D subrange whose top-left is `r, c` and width is `cols`.
|
// Returns 2D subrange whose top-left is `r, c` and width is `cols`.
|
||||||
|
|
|
||||||