mirror of https://github.com/google/gemma.cpp.git
Matmul refactoring towards fusion
MMLoops: move dispatch code out, use overloads split build target into matmul_env (for MatMulEnv/MMOptions) weights: no longer call BindB Fix potential out of bounds in gemma_batch_bench PiperOrigin-RevId: 804895985
This commit is contained in:
parent
34ceee6c30
commit
461a9c7d1b
|
|
@ -46,6 +46,7 @@ jobs:
|
||||||
-D CMAKE_BUILD_TYPE=${{ matrix.build_type }}
|
-D CMAKE_BUILD_TYPE=${{ matrix.build_type }}
|
||||||
-D CMAKE_C_COMPILER_LAUNCHER=ccache
|
-D CMAKE_C_COMPILER_LAUNCHER=ccache
|
||||||
-D CMAKE_CXX_COMPILER_LAUNCHER=ccache
|
-D CMAKE_CXX_COMPILER_LAUNCHER=ccache
|
||||||
|
-DCMAKE_POLICY_VERSION_MINIMUM=3.5
|
||||||
|
|
||||||
- name: Build
|
- name: Build
|
||||||
run: cmake --build ${{ github.workspace }}/build --preset ${{ matrix.preset }} --config ${{ matrix.build_type }} -j 4
|
run: cmake --build ${{ github.workspace }}/build --preset ${{ matrix.preset }} --config ${{ matrix.build_type }} -j 4
|
||||||
|
|
|
||||||
41
BUILD.bazel
41
BUILD.bazel
|
|
@ -238,7 +238,6 @@ cc_library(
|
||||||
":configs",
|
":configs",
|
||||||
":gemma_args",
|
":gemma_args",
|
||||||
":mat",
|
":mat",
|
||||||
":matmul",
|
|
||||||
":model_store",
|
":model_store",
|
||||||
":tensor_info",
|
":tensor_info",
|
||||||
":threading_context",
|
":threading_context",
|
||||||
|
|
@ -271,14 +270,33 @@ test_suite(
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "matmul",
|
name = "matmul_env",
|
||||||
srcs = ["ops/matmul.cc"],
|
srcs = ["ops/matmul.cc"],
|
||||||
hdrs = ["ops/matmul.h"],
|
hdrs = ["ops/matmul.h"],
|
||||||
|
deps = [
|
||||||
|
":allocator",
|
||||||
|
":basics",
|
||||||
|
":configs",
|
||||||
|
":mat",
|
||||||
|
":threading",
|
||||||
|
":threading_context",
|
||||||
|
"@highway//:bit_set",
|
||||||
|
"@highway//:hwy",
|
||||||
|
"@highway//:nanobenchmark",
|
||||||
|
"@highway//:profiler",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "matmul",
|
||||||
|
# allow depending only on this target, without also matmul_env.
|
||||||
|
hdrs = ["ops/matmul.h"],
|
||||||
textual_hdrs = ["ops/matmul-inl.h"],
|
textual_hdrs = ["ops/matmul-inl.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":allocator",
|
":allocator",
|
||||||
":basics",
|
":basics",
|
||||||
":mat",
|
":mat",
|
||||||
|
":matmul_env",
|
||||||
":threading",
|
":threading",
|
||||||
":threading_context",
|
":threading_context",
|
||||||
"//compression:compress",
|
"//compression:compress",
|
||||||
|
|
@ -310,6 +328,7 @@ cc_library(
|
||||||
":basics",
|
":basics",
|
||||||
":mat",
|
":mat",
|
||||||
":matmul",
|
":matmul",
|
||||||
|
":matmul_env",
|
||||||
":threading_context",
|
":threading_context",
|
||||||
"//compression:compress",
|
"//compression:compress",
|
||||||
"//compression:types",
|
"//compression:types",
|
||||||
|
|
@ -333,11 +352,12 @@ cc_library(
|
||||||
":allocator",
|
":allocator",
|
||||||
":basics",
|
":basics",
|
||||||
":mat",
|
":mat",
|
||||||
":matmul",
|
":matmul_env", # MMOptions
|
||||||
":matmul_static",
|
":matmul_static",
|
||||||
":threading_context",
|
":threading_context",
|
||||||
"//compression:compress",
|
"//compression:compress",
|
||||||
"@highway//:algo",
|
"@highway//:algo",
|
||||||
|
"@highway//:bit_set",
|
||||||
"@highway//:hwy",
|
"@highway//:hwy",
|
||||||
"@highway//:math",
|
"@highway//:math",
|
||||||
"@highway//:matvec",
|
"@highway//:matvec",
|
||||||
|
|
@ -434,7 +454,7 @@ cc_test(
|
||||||
deps = [
|
deps = [
|
||||||
":basics",
|
":basics",
|
||||||
":mat",
|
":mat",
|
||||||
":matmul",
|
":matmul_env",
|
||||||
":matmul_static",
|
":matmul_static",
|
||||||
":ops",
|
":ops",
|
||||||
":threading_context",
|
":threading_context",
|
||||||
|
|
@ -462,7 +482,8 @@ cc_test(
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":basics",
|
":basics",
|
||||||
":matmul",
|
":matmul_env",
|
||||||
|
":matmul_static",
|
||||||
":threading_context",
|
":threading_context",
|
||||||
"@googletest//:gtest_main", # buildcleaner: keep
|
"@googletest//:gtest_main", # buildcleaner: keep
|
||||||
"//compression:compress",
|
"//compression:compress",
|
||||||
|
|
@ -495,7 +516,6 @@ cc_library(
|
||||||
":args",
|
":args",
|
||||||
":basics",
|
":basics",
|
||||||
":mat",
|
":mat",
|
||||||
":matmul",
|
|
||||||
"//io",
|
"//io",
|
||||||
"@highway//:hwy",
|
"@highway//:hwy",
|
||||||
"@highway//:profiler",
|
"@highway//:profiler",
|
||||||
|
|
@ -523,13 +543,12 @@ cc_library(
|
||||||
"gemma/gemma-inl.h",
|
"gemma/gemma-inl.h",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":allocator",
|
|
||||||
":basics",
|
":basics",
|
||||||
":configs",
|
":configs",
|
||||||
":gemma_args",
|
":gemma_args",
|
||||||
":kv_cache",
|
":kv_cache",
|
||||||
":mat",
|
":mat",
|
||||||
":matmul",
|
":matmul_env",
|
||||||
":model_store",
|
":model_store",
|
||||||
":ops",
|
":ops",
|
||||||
":threading",
|
":threading",
|
||||||
|
|
@ -569,7 +588,7 @@ cc_library(
|
||||||
":cross_entropy",
|
":cross_entropy",
|
||||||
":gemma_args",
|
":gemma_args",
|
||||||
":gemma_lib",
|
":gemma_lib",
|
||||||
":matmul",
|
":matmul_env",
|
||||||
":ops",
|
":ops",
|
||||||
":threading_context",
|
":threading_context",
|
||||||
":tokenizer",
|
":tokenizer",
|
||||||
|
|
@ -600,7 +619,7 @@ cc_library(
|
||||||
":gemma_args",
|
":gemma_args",
|
||||||
":gemma_lib",
|
":gemma_lib",
|
||||||
":kv_cache",
|
":kv_cache",
|
||||||
":matmul",
|
":matmul_env",
|
||||||
":threading",
|
":threading",
|
||||||
":threading_context",
|
":threading_context",
|
||||||
":tokenizer",
|
":tokenizer",
|
||||||
|
|
@ -661,7 +680,7 @@ cc_binary(
|
||||||
":benchmark_helper",
|
":benchmark_helper",
|
||||||
":gemma_args",
|
":gemma_args",
|
||||||
":gemma_lib",
|
":gemma_lib",
|
||||||
":matmul",
|
":matmul_env",
|
||||||
":tokenizer",
|
":tokenizer",
|
||||||
"//compression:types",
|
"//compression:types",
|
||||||
"//paligemma:image",
|
"//paligemma:image",
|
||||||
|
|
|
||||||
|
|
@ -93,7 +93,8 @@ TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
|
||||||
if (qpos == questions.size()) qpos = 0;
|
if (qpos == questions.size()) qpos = 0;
|
||||||
}
|
}
|
||||||
std::vector<std::string> responses = BatchGemmaReply(inputs);
|
std::vector<std::string> responses = BatchGemmaReply(inputs);
|
||||||
for (size_t i = 0; i < hwy::Unpredictable1() * 3; ++i) {
|
for (size_t i = 0; i < HWY_MIN(hwy::Unpredictable1() * 3, responses.size());
|
||||||
|
++i) {
|
||||||
fprintf(stderr, "Batch answer %zu '%s'\n\n", i, responses[i].c_str());
|
fprintf(stderr, "Batch answer %zu '%s'\n\n", i, responses[i].c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ cc_library(
|
||||||
deps = [
|
deps = [
|
||||||
"//:gemma_args",
|
"//:gemma_args",
|
||||||
"//:gemma_lib",
|
"//:gemma_lib",
|
||||||
"//:matmul",
|
"//:matmul_env",
|
||||||
"//:threading_context",
|
"//:threading_context",
|
||||||
"//:tokenizer",
|
"//:tokenizer",
|
||||||
"@highway//:hwy",
|
"@highway//:hwy",
|
||||||
|
|
|
||||||
|
|
@ -24,8 +24,7 @@
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include "io/io.h" // Path
|
#include "io/io.h" // Path
|
||||||
#include "ops/matmul.h" // MMStorage::kMax*
|
|
||||||
#include "util/args.h"
|
#include "util/args.h"
|
||||||
#include "util/basics.h" // Tristate
|
#include "util/basics.h" // Tristate
|
||||||
#include "util/mat.h"
|
#include "util/mat.h"
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,6 @@
|
||||||
#include "gemma/gemma_args.h"
|
#include "gemma/gemma_args.h"
|
||||||
#include "gemma/model_store.h"
|
#include "gemma/model_store.h"
|
||||||
#include "io/blob_store.h"
|
#include "io/blob_store.h"
|
||||||
#include "ops/matmul.h" // MMParallel
|
|
||||||
#include "util/mat.h"
|
#include "util/mat.h"
|
||||||
#include "util/threading_context.h"
|
#include "util/threading_context.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
|
|
@ -338,7 +337,6 @@ static void AllocateAndBindAll(std::vector<TensorToRead>& tensors,
|
||||||
|
|
||||||
owners[start + task].AllocateFor(*tensor.mat, ctx.allocator,
|
owners[start + task].AllocateFor(*tensor.mat, ctx.allocator,
|
||||||
tensor.padding);
|
tensor.padding);
|
||||||
BindB(ctx, *tensor.mat, tensor.mat->ElementBytes());
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
541
ops/matmul-inl.h
541
ops/matmul-inl.h
|
|
@ -152,10 +152,10 @@ class MMStoreHorizontalSumsIntoC {
|
||||||
// four 4-wide vectors to `C` starting at `(row_c, col_c)`. If `tag` is
|
// four 4-wide vectors to `C` starting at `(row_c, col_c)`. If `tag` is
|
||||||
// `MMSetC`, the vectors are written as-is (first call, or small K).
|
// `MMSetC`, the vectors are written as-is (first call, or small K).
|
||||||
// Otherwise, they are partial sums and are accumulated into C.
|
// Otherwise, they are partial sums and are accumulated into C.
|
||||||
template <class D4, class V4 = hn::Vec<D4>, class Tag, typename TC>
|
template <class D4, class V4 = hn::Vec<D4>, class Tag, class CRows>
|
||||||
HWY_INLINE void Store(D4 d4, V4 sum0, V4 sum1, V4 sum2, V4 sum3, Tag tag,
|
HWY_INLINE void Store(D4 d4, V4 sum0, V4 sum1, V4 sum2, V4 sum3, Tag tag,
|
||||||
const size_t row_c, const size_t col_c,
|
const size_t row_c, const size_t col_c,
|
||||||
const MMArgs& args, RowPtrs<TC> C_rows) const {
|
const MMArgs& args, CRows C_rows) const {
|
||||||
const V4 vscale = hn::Set(d4, args.scale);
|
const V4 vscale = hn::Set(d4, args.scale);
|
||||||
HWY_ALIGN static constexpr float kZero[4] = {};
|
HWY_ALIGN static constexpr float kZero[4] = {};
|
||||||
const V4 vadd = hn::Load(d4, args.add ? args.add + col_c : kZero);
|
const V4 vadd = hn::Load(d4, args.add ? args.add + col_c : kZero);
|
||||||
|
|
@ -219,18 +219,24 @@ class MMStoreHorizontalSumsIntoC {
|
||||||
}
|
}
|
||||||
}; // MMStoreHorizontalSumsIntoC
|
}; // MMStoreHorizontalSumsIntoC
|
||||||
|
|
||||||
// Stateless, wraps member functions.
|
// Stateless, wraps member functions. Contains the innermost 2-4 loops.
|
||||||
class MMKernel {
|
class MMKernel {
|
||||||
|
// Compute size of per-worker storage for `kNR` row ranges of B. Stack
|
||||||
|
// allocation avoids passing a worker index.
|
||||||
|
static constexpr size_t B_stride_max =
|
||||||
|
kMaxKC + 2 * CacheInfo::MaxLineBytes() / sizeof(BF16);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
// Calls `LoopKC` for each of `mc` rows of A in steps of `mr`. `A_view`
|
// Calls `LoopKC` for each of `mc` rows of A in steps of `mr`. `A_view`
|
||||||
// is `mc x kc` and `B_view` is `(kNR x kc)`. Both start at row/col 0.
|
// is `mc x kc` and `B_view` is `(kNR x kc)`. Both start at row/col 0.
|
||||||
// A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output.
|
// A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output.
|
||||||
template <class Tag, typename TC>
|
// Called by B3A2C0 and by callers that hoist `A_view`.
|
||||||
|
template <class Tag, class CRows>
|
||||||
static HWY_INLINE void A2C0(const StridedViewBF A_view,
|
static HWY_INLINE void A2C0(const StridedViewBF A_view,
|
||||||
const StridedViewBF B_view, size_t mr,
|
const StridedViewBF B_view, size_t mr,
|
||||||
const IndexRange& range_mc, const size_t row_b,
|
const IndexRange& range_mc, const size_t row_b,
|
||||||
size_t kc, Tag tag, const MMArgs& args,
|
size_t kc, Tag tag, const MMArgs& args,
|
||||||
RowPtrs<TC> C_rows) {
|
CRows C_rows) {
|
||||||
HWY_DASSERT(1 <= mr && mr <= kMaxMR);
|
HWY_DASSERT(1 <= mr && mr <= kMaxMR);
|
||||||
const size_t row0 = range_mc.begin();
|
const size_t row0 = range_mc.begin();
|
||||||
const size_t mc = range_mc.Num();
|
const size_t mc = range_mc.Num();
|
||||||
|
|
@ -280,6 +286,90 @@ class MMKernel {
|
||||||
HWY_DASSERT(imc == mc);
|
HWY_DASSERT(imc == mc);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static constexpr size_t B_storage_max = kNR * B_stride_max;
|
||||||
|
|
||||||
|
// Returns 2D subrange whose top-left is `r, c` and width is `cols`.
|
||||||
|
template <typename T>
|
||||||
|
static StridedView<T> View(const MatPtrT<T>& AB, size_t r, size_t c,
|
||||||
|
size_t cols) {
|
||||||
|
HWY_DASSERT(c < AB.Cols());
|
||||||
|
HWY_DASSERT(cols <= AB.Cols() - c);
|
||||||
|
return StridedView<T>(const_cast<T*>(AB.Row(r)) + c, cols, AB.Stride());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decompresses `kNR x kc` from `B[row_b, range_kc.begin()]` to row 0,
|
||||||
|
// col 0 of `B_view`. Decompressing SFP is relatively cheap on `AVX3_DL`
|
||||||
|
// thanks to its large table lookups, and less so on other targets.
|
||||||
|
template <typename TB>
|
||||||
|
static StridedViewBF DecompressB(const MatPtrT<TB>& B, const size_t row_b,
|
||||||
|
const IndexRange& range_kc,
|
||||||
|
const StridedViewBF B_view) {
|
||||||
|
const hn::ScalableTag<BF16> dbf;
|
||||||
|
HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf);
|
||||||
|
|
||||||
|
// Neither A nor B require padding because `LoopKC` handles remainders.
|
||||||
|
if constexpr (hwy::IsSame<TB, BF16>()) {
|
||||||
|
return View(B, row_b, range_kc.begin(), range_kc.Num());
|
||||||
|
}
|
||||||
|
|
||||||
|
const PackedSpan<const TB> B_span = B.PaddedSpan();
|
||||||
|
|
||||||
|
const size_t kc = range_kc.Num();
|
||||||
|
const size_t col0 = range_kc.begin();
|
||||||
|
|
||||||
|
for (size_t r = 0; r < kNR; ++r) {
|
||||||
|
const size_t packed_ofs = (row_b + r) * B.Stride() + col0;
|
||||||
|
BF16* HWY_RESTRICT to = B_view.Row(r);
|
||||||
|
DecompressAndZeroPad(dbf, B_span, packed_ofs, to, kc);
|
||||||
|
// Verify that we zero-padded.
|
||||||
|
if constexpr (HWY_IS_DEBUG_BUILD) {
|
||||||
|
for (size_t i = kc; i < hwy::RoundUpTo(kc, NBF); ++i) {
|
||||||
|
HWY_DASSERT(hwy::ConvertScalarTo<float>(to[i]) == 0.0f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return B_view;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Loop over NC/MC/KC, called from the outer loops. The MOMMS B3A2C0 reads
|
||||||
|
// `mc x kc` of A, `nc x kc` of B, and updates `mc x nc` of C. Called by
|
||||||
|
// `ForeachKC` and when there is only a single KC task.
|
||||||
|
template <typename TB, typename Tag, class CRows>
|
||||||
|
static void B3A2C0(const StridedViewBF A, const MatPtrT<TB>& B,
|
||||||
|
const MMArgs& args, const IndexRange& range_mc,
|
||||||
|
const IndexRange& range_kc, const IndexRange& range_nc,
|
||||||
|
size_t mr, Tag out_tag, CRows C_rows) {
|
||||||
|
HWY_ALIGN BF16 B_storage[B_storage_max];
|
||||||
|
|
||||||
|
const size_t kc = range_kc.Num();
|
||||||
|
const StridedViewBF A_view = A.View(range_mc.begin(), range_kc.begin(), kc);
|
||||||
|
|
||||||
|
const size_t B_stride =
|
||||||
|
Stride(MatPadding::kOdd, kc, sizeof(BF16), args.line_bytes);
|
||||||
|
const StridedViewBF B_storage_view(B_storage, kc, B_stride);
|
||||||
|
|
||||||
|
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
|
||||||
|
row_b += kNR) {
|
||||||
|
StridedViewBF B_view = DecompressB(B, row_b, range_kc, B_storage_view);
|
||||||
|
A2C0(A_view, B_view, mr, range_mc, row_b, kc, out_tag, args, C_rows);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename TB, class CRows>
|
||||||
|
static void ForeachKC(const StridedViewBF A, const MatPtrT<TB>& B,
|
||||||
|
const MMArgs& args, const IndexRange& range_mc,
|
||||||
|
const IndexRangePartition& ranges_kc,
|
||||||
|
const IndexRange& range_nc, size_t mr, CRows C_rows) {
|
||||||
|
// Peel off the first iteration of the kc loop: avoid zero-initializing `C`
|
||||||
|
// by writing directly into it, and later accumulating into it.
|
||||||
|
ranges_kc.VisitFirst([&](const IndexRange& range_kc) {
|
||||||
|
B3A2C0(A, B, args, range_mc, range_kc, range_nc, mr, MMSetC(), C_rows);
|
||||||
|
});
|
||||||
|
ranges_kc.VisitRemaining([&](const IndexRange& range_kc) {
|
||||||
|
B3A2C0(A, B, args, range_mc, range_kc, range_nc, mr, MMAddC(), C_rows);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Element-wise multiplies a vector from one row of A with `kNR` vectors,
|
// Element-wise multiplies a vector from one row of A with `kNR` vectors,
|
||||||
// each from a row of transposed B, and adds them to `kNR` fp32 `Cc`
|
// each from a row of transposed B, and adds them to `kNR` fp32 `Cc`
|
||||||
|
|
@ -372,11 +462,11 @@ class MMKernel {
|
||||||
// from range_mc-relative `imc` and `B_view` from row 0 (both at column 0).
|
// from range_mc-relative `imc` and `B_view` from row 0 (both at column 0).
|
||||||
// Updates a `kRowsAC x kNR` tile with top-left `C.Row(row_ac) + col_c`.
|
// Updates a `kRowsAC x kNR` tile with top-left `C.Row(row_ac) + col_c`.
|
||||||
// `A` and `B` are always BF16, `C` can be F32 or BF16.
|
// `A` and `B` are always BF16, `C` can be F32 or BF16.
|
||||||
template <size_t kRowsAC, /*deduced:*/ class Tag, typename TC>
|
template <size_t kRowsAC, /*deduced:*/ class Tag, class CRows>
|
||||||
static HWY_INLINE void LoopKC(const StridedViewBF A_view,
|
static HWY_INLINE void LoopKC(const StridedViewBF A_view,
|
||||||
const StridedViewBF B_view, size_t row_ac,
|
const StridedViewBF B_view, size_t row_ac,
|
||||||
size_t imc, size_t col_c, size_t kc, Tag tag,
|
size_t imc, size_t col_c, size_t kc, Tag tag,
|
||||||
const MMArgs& args, RowPtrs<TC> C_rows) {
|
const MMArgs& args, CRows C_rows) {
|
||||||
const hn::ScalableTag<BF16> dbf;
|
const hn::ScalableTag<BF16> dbf;
|
||||||
using VBF = hn::Vec<decltype(dbf)>;
|
using VBF = hn::Vec<decltype(dbf)>;
|
||||||
HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf);
|
HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf);
|
||||||
|
|
@ -601,7 +691,7 @@ class MMImpl {
|
||||||
size_t vector_bytes,
|
size_t vector_bytes,
|
||||||
MatMulEnv::PerCluster& per_cluster) {
|
MatMulEnv::PerCluster& per_cluster) {
|
||||||
const MMKeys::Key key = MMKeys::KeyFromDims(M, K, N);
|
const MMKeys::Key key = MMKeys::KeyFromDims(M, K, N);
|
||||||
intptr_t index = MMImpl::IndexOfKey(key, per_cluster.keys);
|
intptr_t index = IndexOfKey(key, per_cluster.keys);
|
||||||
// First time we see this shape/key.
|
// First time we see this shape/key.
|
||||||
if (HWY_UNLIKELY(index < 0)) {
|
if (HWY_UNLIKELY(index < 0)) {
|
||||||
per_cluster.keys.Append(key, vector_bytes);
|
per_cluster.keys.Append(key, vector_bytes);
|
||||||
|
|
@ -614,9 +704,9 @@ class MMImpl {
|
||||||
return per_cluster.per_key[index];
|
return per_cluster.per_key[index];
|
||||||
}
|
}
|
||||||
|
|
||||||
static void NotifyAutotuneResult(size_t M, size_t K, size_t N, double t0,
|
static void NotifyAutotuneResult(MatMulEnv& env, size_t M, size_t K, size_t N,
|
||||||
const MMConfig& cfg, MatMulEnv& env,
|
double t0, MMAutoTune<MMConfig>& tuner,
|
||||||
MMAutoTune<MMConfig>& tuner) {
|
const MMConfig& cfg) {
|
||||||
const uint64_t t1 =
|
const uint64_t t1 =
|
||||||
env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
|
env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
|
||||||
const double min_elapsed = static_cast<double>(tuner.NotifyTicks(t1 - t0)) /
|
const double min_elapsed = static_cast<double>(tuner.NotifyTicks(t1 - t0)) /
|
||||||
|
|
@ -653,39 +743,16 @@ class MMImpl {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static size_t Worker(const MMArgs& args) {
|
static size_t Worker(const MatMulEnv& env, size_t cluster_idx) {
|
||||||
return args.options.cluster_idx *
|
return cluster_idx * env.ctx.pools.MaxWorkersPerCluster();
|
||||||
args.env->ctx.pools.MaxWorkersPerCluster();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns 2D subrange whose top-left is `r, c` and width is `cols`.
|
|
||||||
template <typename T>
|
|
||||||
static StridedView<T> View(const MatPtrT<T>& AB, size_t r, size_t c,
|
|
||||||
size_t cols) {
|
|
||||||
HWY_DASSERT(c < AB.Cols());
|
|
||||||
HWY_DASSERT(cols <= AB.Cols() - c);
|
|
||||||
return StridedView<T>(const_cast<T*>(AB.Row(r)) + c, cols, AB.Stride());
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class Func>
|
|
||||||
static void DispatchParallelism(ParallelismStrategy parallelism,
|
|
||||||
const Func& func) {
|
|
||||||
switch (parallelism) {
|
|
||||||
case ParallelismStrategy::kHierarchical:
|
|
||||||
return func(MMParallelHierarchical());
|
|
||||||
case ParallelismStrategy::kNone:
|
|
||||||
return func(MMParallelNone());
|
|
||||||
case ParallelismStrategy::kWithinCluster:
|
|
||||||
return func(MMParallelWithinCluster());
|
|
||||||
default:
|
|
||||||
HWY_UNREACHABLE;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decompresses all `M x K` from `A` into padded BF16 `A_view`.
|
// Decompresses all `M x K` from `A` into padded BF16 `A_view`.
|
||||||
static HWY_NOINLINE void DoDecompressA(const MatPtrT<float>& A,
|
static HWY_NOINLINE void DoDecompressA(const MatPtrT<float>& A,
|
||||||
const StridedViewBF A_view,
|
const StridedViewBF A_view,
|
||||||
MMParA par_a, const MMArgs& args) {
|
MMAutoTune<MMParA>& autotune,
|
||||||
|
MMParA par_a, const MatMulEnv& env,
|
||||||
|
const MMOptions& options) {
|
||||||
const IndexRange all_M(0, A.Rows());
|
const IndexRange all_M(0, A.Rows());
|
||||||
const IndexRange all_K(0, A.Cols());
|
const IndexRange all_K(0, A.Cols());
|
||||||
HWY_DASSERT(all_K.Num() == A_view.Cols());
|
HWY_DASSERT(all_K.Num() == A_view.Cols());
|
||||||
|
|
@ -693,13 +760,13 @@ class MMImpl {
|
||||||
const hn::ScalableTag<BF16> dbf;
|
const hn::ScalableTag<BF16> dbf;
|
||||||
const size_t NBF = hn::Lanes(dbf);
|
const size_t NBF = hn::Lanes(dbf);
|
||||||
|
|
||||||
static const auto zone = args.env->ctx.profiler.AddZone("MM.DecompressA");
|
static const auto zone = env.ctx.profiler.AddZone("MM.DecompressA");
|
||||||
|
|
||||||
const auto do_range =
|
const auto do_range =
|
||||||
[&](const IndexRange& range_M, const IndexRange& range_K, size_t worker)
|
[&](const IndexRange& range_M, const IndexRange& range_K, size_t worker)
|
||||||
HWY_ATTR {
|
HWY_ATTR {
|
||||||
MMZone mm_zone;
|
MMZone mm_zone;
|
||||||
mm_zone.MaybeEnter(worker, zone, args);
|
mm_zone.MaybeEnter(worker, zone, env, &autotune);
|
||||||
|
|
||||||
const size_t col0 = range_K.begin();
|
const size_t col0 = range_K.begin();
|
||||||
const size_t cols = range_K.Num();
|
const size_t cols = range_K.Num();
|
||||||
|
|
@ -722,7 +789,7 @@ class MMImpl {
|
||||||
|
|
||||||
switch (par_a) {
|
switch (par_a) {
|
||||||
case MMParA::kNone:
|
case MMParA::kNone:
|
||||||
do_range(all_M, all_K, MMImpl::Worker(args));
|
do_range(all_M, all_K, Worker(env, options.cluster_idx));
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case MMParA::kK1:
|
case MMParA::kK1:
|
||||||
|
|
@ -732,27 +799,26 @@ class MMImpl {
|
||||||
// At least one vector, otherwise DecompressAndZeroPad will add
|
// At least one vector, otherwise DecompressAndZeroPad will add
|
||||||
// padding, which might overwrite neighboring tasks. Also a whole cache
|
// padding, which might overwrite neighboring tasks. Also a whole cache
|
||||||
// line to avoid false sharing.
|
// line to avoid false sharing.
|
||||||
const size_t multiple_K = HWY_MAX(NBF, args.line_bytes / sizeof(BF16));
|
const size_t multiple_K =
|
||||||
|
HWY_MAX(NBF, env.ctx.cache_info.LineBytes() / sizeof(BF16));
|
||||||
|
|
||||||
DispatchParallelism(
|
DispatchParallelism(options.parallelism, [&](const auto& parallel) {
|
||||||
args.options.parallelism, [&](const auto& parallel) {
|
parallel.ForN(env.ctx, all_K, multiple_K, inner_tasks,
|
||||||
parallel.ForN(args.env->ctx, all_K, multiple_K, inner_tasks,
|
options.cluster_idx,
|
||||||
args.options.cluster_idx,
|
[&](const IndexRange& range_K, size_t worker) {
|
||||||
[&](const IndexRange& range_K, size_t worker) {
|
do_range(all_M, range_K, worker);
|
||||||
do_range(all_M, range_K, worker);
|
});
|
||||||
});
|
});
|
||||||
});
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case MMParA::kM:
|
case MMParA::kM:
|
||||||
DispatchParallelism(
|
DispatchParallelism(options.parallelism, [&](const auto& parallel) {
|
||||||
args.options.parallelism, [&](const auto& parallel) {
|
parallel.ForRangeMC(env.ctx, all_M, options.cluster_idx,
|
||||||
parallel.ForRangeMC(
|
[&](size_t row_a, size_t worker) {
|
||||||
args.env->ctx, all_M, args.options.cluster_idx,
|
do_range(IndexRange(row_a, row_a + 1), all_K,
|
||||||
[&](size_t row_a, size_t worker) {
|
worker);
|
||||||
do_range(IndexRange(row_a, row_a + 1), all_K, worker);
|
});
|
||||||
});
|
});
|
||||||
});
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -760,11 +826,11 @@ class MMImpl {
|
||||||
// Autotuning wrapper for `DoDecompressA`.
|
// Autotuning wrapper for `DoDecompressA`.
|
||||||
static HWY_INLINE void DecompressA(const MatPtrT<float>& A,
|
static HWY_INLINE void DecompressA(const MatPtrT<float>& A,
|
||||||
const StridedViewBF A_view,
|
const StridedViewBF A_view,
|
||||||
const MMArgs& args) {
|
MMAutoTune<MMParA>& autotune,
|
||||||
MMAutoTune<MMParA>& autotune = args.per_key->autotune_par_a;
|
const MatMulEnv& env,
|
||||||
|
const MMOptions& options) {
|
||||||
if (HWY_LIKELY(autotune.Best())) {
|
if (HWY_LIKELY(autotune.Best())) {
|
||||||
return DoDecompressA(A, A_view, *autotune.Best(), args);
|
return DoDecompressA(A, A_view, autotune, *autotune.Best(), env, options);
|
||||||
}
|
}
|
||||||
|
|
||||||
// First call: generate candidates.
|
// First call: generate candidates.
|
||||||
|
|
@ -777,11 +843,11 @@ class MMImpl {
|
||||||
|
|
||||||
const MMParA& par_a = autotune.NextConfig();
|
const MMParA& par_a = autotune.NextConfig();
|
||||||
const uint64_t t0 = hwy::timer::Start();
|
const uint64_t t0 = hwy::timer::Start();
|
||||||
DoDecompressA(A, A_view, par_a, args);
|
DoDecompressA(A, A_view, autotune, par_a, env, options);
|
||||||
const uint64_t t1 =
|
const uint64_t t1 =
|
||||||
args.env->have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
|
env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
|
||||||
const uint64_t min_elapsed = autotune.NotifyTicks(t1 - t0);
|
const uint64_t min_elapsed = autotune.NotifyTicks(t1 - t0);
|
||||||
if (HWY_UNLIKELY(args.env->print_measurement && autotune.ShouldPrint())) {
|
if (HWY_UNLIKELY(env.print_measurement && autotune.ShouldPrint())) {
|
||||||
fprintf(stderr, "%s,%7.3f\n", StringFromParA(par_a),
|
fprintf(stderr, "%s,%7.3f\n", StringFromParA(par_a),
|
||||||
static_cast<double>(min_elapsed) /
|
static_cast<double>(min_elapsed) /
|
||||||
hwy::platform::InvariantTicksPerSecond() * 1E6);
|
hwy::platform::InvariantTicksPerSecond() * 1E6);
|
||||||
|
|
@ -790,299 +856,148 @@ class MMImpl {
|
||||||
|
|
||||||
template <typename TA>
|
template <typename TA>
|
||||||
static HWY_INLINE StridedViewBF MaybeDecompressA(const MatPtrT<TA>& A,
|
static HWY_INLINE StridedViewBF MaybeDecompressA(const MatPtrT<TA>& A,
|
||||||
const MMArgs& args) {
|
MMAutoTune<MMParA>& autotune,
|
||||||
|
const MatMulEnv& env,
|
||||||
|
MMOptions options) {
|
||||||
if constexpr (IsBF16<TA>()) {
|
if constexpr (IsBF16<TA>()) {
|
||||||
// We can use a view, regardless of columns/padding, because `LoopKC`
|
// We can use a view, regardless of columns/padding, because `LoopKC`
|
||||||
// supports non-vector multiples.
|
// supports non-vector multiples.
|
||||||
return View(A, 0, 0, A.Cols());
|
return MMKernel::View(A, 0, 0, A.Cols());
|
||||||
} else {
|
} else {
|
||||||
// Always decompress. To reduce code size/compile time, we no longer
|
// Always decompress. To reduce code size/compile time, we no longer
|
||||||
// support a separate F32 kernel; most A are already BF16. We also only
|
// support a separate F32 kernel; most A are already BF16. We also only
|
||||||
// have a single MMStorage.
|
// have a single MMStorage.
|
||||||
HWY_ASSERT(args.options.cluster_idx == 0);
|
HWY_ASSERT(options.cluster_idx == 0);
|
||||||
const StridedViewBF A_view = args.env->storage.A(A.Extents());
|
const StridedViewBF A_view = env.storage.A(A.Extents());
|
||||||
DecompressA(A, A_view, args);
|
DecompressA(A, A_view, autotune, env, options);
|
||||||
return A_view;
|
return A_view;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Contains several variants of the outer M/N/K loops, and calls `A2C0` which
|
// Defines several variants of the outer M/N/K loops (see `MMOrder`).
|
||||||
// loops over the inner KC and MC. Member variables avoid long argument lists.
|
class MMLoops {
|
||||||
class MMState {
|
|
||||||
public:
|
public:
|
||||||
MMState(size_t M, size_t K, size_t N, const MMArgs& args,
|
|
||||||
const MMConfig& config)
|
|
||||||
: args_(args),
|
|
||||||
range_n_(0, N),
|
|
||||||
mr_(config.MR()),
|
|
||||||
ranges_mc_(config.RangesOfMC(M)),
|
|
||||||
ranges_kc_(config.RangesOfKC(K)),
|
|
||||||
ranges_nc_(config.RangesOfNC(N)),
|
|
||||||
order_(config.Order()),
|
|
||||||
inner_tasks_(config.InnerTasks()) {}
|
|
||||||
|
|
||||||
// Called from `MatMul` from two places: either with the next autotune config,
|
// Called from `MatMul` from two places: either with the next autotune config,
|
||||||
// or with the best config.
|
// or with the best config.
|
||||||
template <typename TB, typename TC>
|
template <typename TB, typename TC>
|
||||||
HWY_NOINLINE void DispatchParallelism(const StridedViewBF A,
|
static HWY_NOINLINE void Dispatch(const StridedViewBF A, const MatPtrT<TB>& B,
|
||||||
const MatPtrT<TB>& B,
|
RowPtrs<TC> C_rows, const MMArgs& args) {
|
||||||
RowPtrs<TC> C_rows) const {
|
static const auto zone = args.env.ctx.profiler.AddZone("MM.Dispatch");
|
||||||
static const auto zone =
|
PROFILER_ZONE3(args.env.ctx.profiler,
|
||||||
args_.env->ctx.profiler.AddZone("MM.DispatchParallelism");
|
MMImpl::Worker(args.env, args.options.cluster_idx), zone);
|
||||||
PROFILER_ZONE3(args_.env->ctx.profiler, MMImpl::Worker(args_), zone);
|
|
||||||
|
|
||||||
MMImpl::DispatchParallelism(
|
DispatchParallelism(
|
||||||
args_.options.parallelism,
|
args.options.parallelism, [&](const auto& parallel) HWY_ATTR {
|
||||||
[&](const auto& parallel) { DispatchOrder(parallel, A, B, C_rows); });
|
DispatchOrder(args.order, [&](const auto& order) HWY_ATTR {
|
||||||
|
Loop(order, parallel, A, B, C_rows, args);
|
||||||
|
});
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Compute size of per-worker storage for `kNR` row ranges of B. Stack
|
|
||||||
// allocation avoids passing a worker index.
|
|
||||||
static constexpr size_t B_stride_max_ =
|
|
||||||
kMaxKC + 2 * CacheInfo::MaxLineBytes() / sizeof(BF16);
|
|
||||||
static constexpr size_t B_storage_max_ = kNR * B_stride_max_;
|
|
||||||
|
|
||||||
// Granularity of `ForN`. B rows produce C columns, so we
|
// Granularity of `ForN`. B rows produce C columns, so we
|
||||||
// want a multiple of the line size to prevent false sharing.
|
// want a multiple of the line size to prevent false sharing.
|
||||||
size_t MultipleN(size_t sizeof_TC) const {
|
static size_t MultipleN(size_t sizeof_TC, size_t line_bytes) {
|
||||||
return HWY_MAX(kNR, args_.line_bytes / sizeof_TC);
|
return HWY_MAX(kNR, line_bytes / sizeof_TC);
|
||||||
}
|
|
||||||
|
|
||||||
// B is decompressed several call layers lower, but not all member functions
|
|
||||||
// depend on `TB`, so pass it as an argument instead of templating the class.
|
|
||||||
template <typename TB, typename TC, class ParallelT>
|
|
||||||
HWY_NOINLINE void DispatchOrder(const ParallelT& parallel_policy,
|
|
||||||
const StridedViewBF A, const MatPtrT<TB>& B,
|
|
||||||
RowPtrs<TC> C_rows) const {
|
|
||||||
switch (order_) {
|
|
||||||
case MMOrder::kNT:
|
|
||||||
return DoNT(parallel_policy, A, B, C_rows);
|
|
||||||
case MMOrder::kNT_K:
|
|
||||||
return DoNT_K(parallel_policy, A, B, C_rows);
|
|
||||||
case MMOrder::kNT_MT:
|
|
||||||
return DoNT_MT(parallel_policy, A, B, C_rows);
|
|
||||||
case MMOrder::kNT_MT_K:
|
|
||||||
return DoNT_MT_K(parallel_policy, A, B, C_rows);
|
|
||||||
default:
|
|
||||||
HWY_UNREACHABLE;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Single M and K ranges, parallel N. Fills all of C directly.
|
// Single M and K ranges, parallel N. Fills all of C directly.
|
||||||
template <typename TB, typename TC, class ParallelT>
|
template <typename TB, typename TC, class Parallel>
|
||||||
HWY_INLINE void DoNT(ParallelT parallel, const StridedViewBF A,
|
static HWY_INLINE void Loop(MMOrderNT, Parallel parallel,
|
||||||
const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
|
const StridedViewBF A, const MatPtrT<TB>& B,
|
||||||
static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT");
|
RowPtrs<TC> C_rows, const MMArgs& args) {
|
||||||
HWY_DASSERT(ranges_mc_.NumTasks() == 1);
|
static const auto zone = args.env.ctx.profiler.AddZone("MM.NT");
|
||||||
HWY_DASSERT(ranges_kc_.NumTasks() == 1);
|
HWY_DASSERT(args.ranges_mc.NumTasks() == 1);
|
||||||
const IndexRange& range_M = ranges_mc_.Range(0);
|
HWY_DASSERT(args.ranges_kc.NumTasks() == 1);
|
||||||
const IndexRange& range_K = ranges_kc_.Range(0);
|
const IndexRange& range_M = args.ranges_mc.Range(0);
|
||||||
|
const IndexRange& range_K = args.ranges_kc.Range(0);
|
||||||
const size_t K = range_K.Num();
|
const size_t K = range_K.Num();
|
||||||
const StridedViewBF A_view = A.View(range_M.begin(), 0, K);
|
const StridedViewBF A_view = A.View(range_M.begin(), 0, K);
|
||||||
const size_t B_stride =
|
const size_t B_stride =
|
||||||
Stride(MatPadding::kOdd, K, sizeof(BF16), args_.line_bytes);
|
Stride(MatPadding::kOdd, K, sizeof(BF16), args.line_bytes);
|
||||||
|
|
||||||
// Similar to `loop_nc` below, but here we hoisted `A_view`.
|
// Similar to `B3A2C0`, but here we hoisted `A_view`.
|
||||||
parallel.ForN(
|
parallel.ForN(
|
||||||
args_.env->ctx, range_n_, MultipleN(sizeof(TC)), inner_tasks_,
|
args.env.ctx, args.range_n, MultipleN(sizeof(TC), args.line_bytes),
|
||||||
args_.options.cluster_idx,
|
args.inner_tasks, args.options.cluster_idx,
|
||||||
[&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
|
[&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
|
||||||
MMZone mm_zone;
|
MMZone mm_zone;
|
||||||
mm_zone.MaybeEnter(worker, zone, args_);
|
mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune);
|
||||||
|
|
||||||
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
|
HWY_ALIGN BF16 B_storage[MMKernel::B_storage_max]; // TLS
|
||||||
const StridedViewBF B_storage_view(B_storage, K, B_stride);
|
const StridedViewBF B_storage_view(B_storage, K, B_stride);
|
||||||
|
|
||||||
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
|
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
|
||||||
row_b += kNR) {
|
row_b += kNR) {
|
||||||
StridedViewBF B_view =
|
StridedViewBF B_view =
|
||||||
DecompressB(B, row_b, range_K, B_storage_view);
|
MMKernel::DecompressB(B, row_b, range_K, B_storage_view);
|
||||||
MMKernel::A2C0(A_view, B_view, mr_, range_M, row_b, K, MMSetC(),
|
MMKernel::A2C0(A_view, B_view, args.mr, range_M, row_b, K, MMSetC(),
|
||||||
args_, C_rows);
|
args, C_rows);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Single M range, parallel N, sequential K. Sets C, then accumulates.
|
// Single M range, parallel N, sequential K. Sets C, then accumulates.
|
||||||
template <typename TB, typename TC, class ParallelT>
|
template <typename TB, typename TC, class Parallel>
|
||||||
HWY_INLINE void DoNT_K(ParallelT parallel, const StridedViewBF A,
|
static HWY_INLINE void Loop(MMOrderNT_K, Parallel parallel,
|
||||||
const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
|
const StridedViewBF A, const MatPtrT<TB>& B,
|
||||||
static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_K");
|
RowPtrs<TC> C_rows, const MMArgs& args) {
|
||||||
HWY_DASSERT(ranges_mc_.NumTasks() == 1);
|
static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_K");
|
||||||
const IndexRange& range_mc = ranges_mc_.Range(0);
|
HWY_DASSERT(args.ranges_mc.NumTasks() == 1);
|
||||||
|
const IndexRange& range_mc = args.ranges_mc.Range(0);
|
||||||
|
|
||||||
// Loop over NC/MC/KC, called from the outer loops over K/N.
|
parallel.ForN(args.env.ctx, args.range_n,
|
||||||
// C++14 generic lambda enables hoisting branches via template
|
MultipleN(sizeof(TC), args.line_bytes), args.inner_tasks,
|
||||||
// argument, while also capturing to avoid long argument lists.
|
args.options.cluster_idx,
|
||||||
const auto loop_nc = [&](BF16* B_storage, const IndexRange& range_kc,
|
[&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
|
||||||
const IndexRange& range_nc,
|
MMZone mm_zone;
|
||||||
auto out_tag) HWY_ATTR {
|
mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune);
|
||||||
const size_t kc = range_kc.Num();
|
MMKernel::ForeachKC(A, B, args, range_mc, args.ranges_kc,
|
||||||
const StridedViewBF A_view =
|
range_nc, args.mr, C_rows);
|
||||||
A.View(range_mc.begin(), range_kc.begin(), kc);
|
});
|
||||||
const StridedViewBF B_storage_view(
|
|
||||||
B_storage, kc,
|
|
||||||
Stride(MatPadding::kOdd, kc, sizeof(BF16), args_.line_bytes));
|
|
||||||
|
|
||||||
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
|
|
||||||
row_b += kNR) {
|
|
||||||
StridedViewBF B_view = DecompressB(B, row_b, range_kc, B_storage_view);
|
|
||||||
MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
|
|
||||||
C_rows);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
parallel.ForN(
|
|
||||||
args_.env->ctx, range_n_, MultipleN(sizeof(TC)), inner_tasks_,
|
|
||||||
args_.options.cluster_idx,
|
|
||||||
[&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
|
|
||||||
MMZone mm_zone;
|
|
||||||
mm_zone.MaybeEnter(worker, zone, args_);
|
|
||||||
|
|
||||||
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
|
|
||||||
|
|
||||||
// Peel off the first iteration of the kc loop: avoid
|
|
||||||
// zero-initializing `partial` by writing into it.
|
|
||||||
ranges_kc_.VisitFirst([&](const IndexRange& range_kc) {
|
|
||||||
loop_nc(B_storage, range_kc, range_nc, MMSetC());
|
|
||||||
});
|
|
||||||
ranges_kc_.VisitRemaining([&](const IndexRange& range_kc) {
|
|
||||||
loop_nc(B_storage, range_kc, range_nc, MMAddC());
|
|
||||||
});
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parallel loops over mc/nc blocks of M/range_n, single K.
|
// Parallel loops over mc/nc blocks of M/range_n, single K.
|
||||||
// Fills `mc x nc` sections of C directly, in parallel.
|
// Fills `mc x nc` sections of C directly, in parallel.
|
||||||
template <typename TB, typename TC, class ParallelT>
|
template <typename TB, typename TC, class Parallel>
|
||||||
HWY_INLINE void DoNT_MT(ParallelT parallel, const StridedViewBF A,
|
static HWY_INLINE void Loop(MMOrderNT_MT, Parallel parallel,
|
||||||
const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
|
const StridedViewBF A, const MatPtrT<TB>& B,
|
||||||
static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT");
|
RowPtrs<TC> C_rows, const MMArgs& args) {
|
||||||
HWY_DASSERT(ranges_kc_.NumTasks() == 1);
|
static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_MT");
|
||||||
const IndexRange& range_K = ranges_kc_.Range(0);
|
HWY_DASSERT(args.ranges_kc.NumTasks() == 1);
|
||||||
const size_t K = range_K.Num();
|
const IndexRange& range_K = args.ranges_kc.Range(0);
|
||||||
const size_t B_stride =
|
|
||||||
Stride(MatPadding::kOdd, K, sizeof(BF16), args_.line_bytes);
|
|
||||||
|
|
||||||
// Similar to `loop_nc` below except for the profiler zone and `MMSetC`.
|
|
||||||
parallel.ForRangesMC_NC(
|
parallel.ForRangesMC_NC(
|
||||||
args_.env->ctx, ranges_mc_, ranges_nc_, args_.options.cluster_idx,
|
args.env.ctx, args.ranges_mc, args.ranges_nc, args.options.cluster_idx,
|
||||||
[&](const IndexRange& range_mc, const IndexRange& range_nc,
|
[&](const IndexRange& range_mc, const IndexRange& range_nc,
|
||||||
size_t worker) HWY_ATTR {
|
size_t worker) HWY_ATTR {
|
||||||
MMZone mm_zone;
|
MMZone mm_zone;
|
||||||
mm_zone.MaybeEnter(worker, zone, args_);
|
mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune);
|
||||||
|
MMKernel::B3A2C0(A, B, args, range_mc, range_K, range_nc, args.mr,
|
||||||
const StridedViewBF A_view = A.View(range_mc.begin(), 0, K);
|
MMSetC(), C_rows);
|
||||||
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
|
|
||||||
const StridedViewBF B_storage_view(B_storage, K, B_stride);
|
|
||||||
|
|
||||||
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
|
|
||||||
row_b += kNR) {
|
|
||||||
const StridedViewBF B_view =
|
|
||||||
DecompressB(B, row_b, range_K, B_storage_view);
|
|
||||||
MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, K, MMSetC(),
|
|
||||||
args_, C_rows);
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parallel loops over mc/nc blocks of M/range_np, sequential K.
|
// Parallel loops over mc/nc blocks of M/range_np, sequential K.
|
||||||
// Accumulates into `mc x nc` sections of `C`.
|
// Accumulates into `mc x nc` sections of `C`.
|
||||||
template <typename TB, typename TC, class ParallelT>
|
template <typename TB, typename TC, class Parallel>
|
||||||
HWY_INLINE void DoNT_MT_K(ParallelT parallel, const StridedViewBF A,
|
static HWY_INLINE void Loop(MMOrderNT_MT_K, Parallel parallel,
|
||||||
const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
|
const StridedViewBF A, const MatPtrT<TB>& B,
|
||||||
static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT_K");
|
RowPtrs<TC> C_rows, const MMArgs& args) {
|
||||||
const size_t kc_max = ranges_kc_.TaskSize();
|
static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_MT_K");
|
||||||
HWY_DASSERT(kc_max <= kMaxKC);
|
|
||||||
const size_t B_stride =
|
|
||||||
Stride(MatPadding::kOdd, kc_max, sizeof(BF16), args_.line_bytes);
|
|
||||||
// Sequential loop over NC/MC/KC, for when the M/N loops are
|
|
||||||
// already parallel. This is B3A2C0 in MOMMS terminology: we read
|
|
||||||
// `mc x kc` of A, `nc x kc` of B, update `mc x nc` of `C`.
|
|
||||||
const auto loop_nc = [&](const StridedViewBF B_storage_view,
|
|
||||||
const IndexRange& range_mc,
|
|
||||||
const IndexRange& range_kc,
|
|
||||||
const IndexRange& range_nc,
|
|
||||||
auto out_tag) HWY_ATTR {
|
|
||||||
const size_t kc = range_kc.Num();
|
|
||||||
const StridedViewBF A_view =
|
|
||||||
A.View(range_mc.begin(), range_kc.begin(), kc);
|
|
||||||
|
|
||||||
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
|
|
||||||
row_b += kNR) {
|
|
||||||
StridedViewBF B_view = DecompressB(B, row_b, range_kc, B_storage_view);
|
|
||||||
MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
|
|
||||||
C_rows);
|
|
||||||
}
|
|
||||||
}; // loop_nc
|
|
||||||
parallel.ForRangesMC_NC(
|
parallel.ForRangesMC_NC(
|
||||||
args_.env->ctx, ranges_mc_, ranges_nc_, args_.options.cluster_idx,
|
args.env.ctx, args.ranges_mc, args.ranges_nc, args.options.cluster_idx,
|
||||||
[&](const IndexRange& range_mc, const IndexRange& range_nc,
|
[&](const IndexRange& range_mc, const IndexRange& range_nc,
|
||||||
size_t worker) HWY_ATTR {
|
size_t worker) HWY_ATTR {
|
||||||
MMZone mm_zone;
|
MMZone mm_zone;
|
||||||
mm_zone.MaybeEnter(worker, zone, args_);
|
mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune);
|
||||||
|
MMKernel::ForeachKC(A, B, args, range_mc, args.ranges_kc, range_nc,
|
||||||
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
|
args.mr, C_rows);
|
||||||
const StridedViewBF B_storage_view(B_storage, kc_max, B_stride);
|
|
||||||
|
|
||||||
// Peel off the first iteration of the kc loop: avoid
|
|
||||||
// zero-initializing `C` by writing into it.
|
|
||||||
ranges_kc_.VisitFirst([&](const IndexRange& range_kc) {
|
|
||||||
loop_nc(B_storage_view, range_mc, range_kc, range_nc, MMSetC());
|
|
||||||
});
|
|
||||||
ranges_kc_.VisitRemaining([&](const IndexRange& range_kc) {
|
|
||||||
loop_nc(B_storage_view, range_mc, range_kc, range_nc, MMAddC());
|
|
||||||
});
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
}; // MMLoops
|
||||||
// Decompresses `kNR x kc` from `B[row_b, range_kc.begin()]` to row 0,
|
|
||||||
// col 0 of `B_view`. Decompressing SFP is relatively cheap on `AVX3_DL`
|
|
||||||
// thanks to its large table lookups, and less so on other targets.
|
|
||||||
template <typename TB>
|
|
||||||
HWY_INLINE StridedViewBF DecompressB(const MatPtrT<TB>& B, const size_t row_b,
|
|
||||||
const IndexRange& range_kc,
|
|
||||||
const StridedViewBF B_view) const {
|
|
||||||
const hn::ScalableTag<BF16> dbf;
|
|
||||||
HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf);
|
|
||||||
|
|
||||||
// Neither A nor B require padding because `LoopKC` handles remainders.
|
|
||||||
if constexpr (hwy::IsSame<TB, BF16>()) {
|
|
||||||
return MMImpl::View(B, row_b, range_kc.begin(), range_kc.Num());
|
|
||||||
}
|
|
||||||
|
|
||||||
const PackedSpan<const TB> B_span = B.PaddedSpan();
|
|
||||||
|
|
||||||
const size_t kc = range_kc.Num();
|
|
||||||
const size_t col0 = range_kc.begin();
|
|
||||||
|
|
||||||
for (size_t r = 0; r < kNR; ++r) {
|
|
||||||
const size_t packed_ofs = (row_b + r) * B.Stride() + col0;
|
|
||||||
BF16* HWY_RESTRICT to = B_view.Row(r);
|
|
||||||
DecompressAndZeroPad(dbf, B_span, packed_ofs, to, kc);
|
|
||||||
// Verify that we zero-padded.
|
|
||||||
if constexpr (HWY_IS_DEBUG_BUILD) {
|
|
||||||
for (size_t i = kc; i < hwy::RoundUpTo(kc, NBF); ++i) {
|
|
||||||
HWY_DASSERT(hwy::ConvertScalarTo<float>(to[i]) == 0.0f);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return B_view;
|
|
||||||
}
|
|
||||||
|
|
||||||
const MMArgs args_; // copy for locality
|
|
||||||
|
|
||||||
const IndexRange range_n_;
|
|
||||||
// From MMConfig:
|
|
||||||
const size_t mr_;
|
|
||||||
const IndexRangePartition ranges_mc_;
|
|
||||||
const IndexRangePartition ranges_kc_;
|
|
||||||
const IndexRangePartition ranges_nc_;
|
|
||||||
const MMOrder order_;
|
|
||||||
const size_t inner_tasks_;
|
|
||||||
}; // MMState
|
|
||||||
|
|
||||||
// Computes the matrix product `A * B * scale [+ add]` and stores it in `C`.
|
// Computes the matrix product `A * B * scale [+ add]` and stores it in `C`.
|
||||||
//
|
//
|
||||||
|
|
@ -1109,29 +1024,30 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
||||||
const float* HWY_RESTRICT add, MatMulEnv& env,
|
const float* HWY_RESTRICT add, MatMulEnv& env,
|
||||||
MatPtrT<TC>& C, MMOptions options = MMOptions()) {
|
MatPtrT<TC>& C, MMOptions options = MMOptions()) {
|
||||||
static const auto zone = env.ctx.profiler.AddZone("MM.MatMul");
|
static const auto zone = env.ctx.profiler.AddZone("MM.MatMul");
|
||||||
|
const size_t cluster_idx = options.cluster_idx;
|
||||||
|
HWY_DASSERT(cluster_idx < env.row_ptrs.size());
|
||||||
PROFILER_ZONE3(env.ctx.profiler,
|
PROFILER_ZONE3(env.ctx.profiler,
|
||||||
options.cluster_idx * env.ctx.pools.MaxWorkersPerCluster(),
|
cluster_idx * env.ctx.pools.MaxWorkersPerCluster(), zone);
|
||||||
zone);
|
|
||||||
|
|
||||||
HWY_DASSERT(options.cluster_idx < env.row_ptrs.size());
|
RowPtrs<TC> C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[cluster_idx]);
|
||||||
RowPtrs<TC> C_rows =
|
|
||||||
GetOrSetTempRowPtrs(C, env.row_ptrs[options.cluster_idx]);
|
|
||||||
|
|
||||||
const size_t M = A.Rows();
|
const size_t M = A.Rows();
|
||||||
const size_t K = A.Cols();
|
const size_t K = A.Cols();
|
||||||
const size_t N = B.Rows();
|
const size_t N = B.Rows();
|
||||||
|
|
||||||
const CacheInfo& cache = env.ctx.cache_info;
|
const CacheInfo& cache = env.ctx.cache_info;
|
||||||
MMPerKey& per_key = MMImpl::FindOrAddPerKey(
|
MMPerKey& per_key = MMImpl::FindOrAddPerKey(M, K, N, cache.VectorBytes(),
|
||||||
M, K, N, cache.VectorBytes(), env.per_cluster[options.cluster_idx]);
|
env.per_cluster[cluster_idx]);
|
||||||
MMAutoTune<MMConfig>& tuner = per_key.autotune;
|
|
||||||
|
|
||||||
const MMArgs args(env, per_key, static_cast<double>(A.Scale()) * B.Scale(),
|
// (Also auto-tunes, hence outside the timed section to prevent interference.)
|
||||||
add, options);
|
const StridedViewBF A_view =
|
||||||
|
MMImpl::MaybeDecompressA(A, per_key.autotune_par_a, env, options);
|
||||||
|
|
||||||
|
MMAutoTune<MMConfig>& tuner = per_key.autotune;
|
||||||
if (HWY_LIKELY(tuner.Best())) {
|
if (HWY_LIKELY(tuner.Best())) {
|
||||||
const MMState state(M, K, N, args, *tuner.Best());
|
const MMArgs args(env, M, K, N, static_cast<double>(A.Scale()) * B.Scale(),
|
||||||
const StridedViewBF A_view = MMImpl::MaybeDecompressA(A, args);
|
add, options, tuner, *tuner.Best());
|
||||||
state.DispatchParallelism(A_view, B, C_rows);
|
MMLoops::Dispatch(A_view, B, C_rows, args);
|
||||||
return &per_key;
|
return &per_key;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1147,14 +1063,13 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
||||||
MMCandidates(cache, M, K, N, sizeof(TC), env.print_config));
|
MMCandidates(cache, M, K, N, sizeof(TC), env.print_config));
|
||||||
}
|
}
|
||||||
|
|
||||||
// (Also auto-tunes, hence outside the timed section to prevent interference.)
|
|
||||||
const StridedViewBF A_view = MMImpl::MaybeDecompressA(A, args);
|
|
||||||
|
|
||||||
const MMConfig& cfg = tuner.NextConfig();
|
const MMConfig& cfg = tuner.NextConfig();
|
||||||
|
const MMArgs args(env, M, K, N, static_cast<double>(A.Scale()) * B.Scale(),
|
||||||
|
add, options, tuner, cfg);
|
||||||
|
|
||||||
const uint64_t t0 = hwy::timer::Start();
|
const uint64_t t0 = hwy::timer::Start();
|
||||||
MMState state(M, K, N, args, cfg);
|
MMLoops::Dispatch(A_view, B, C_rows, args);
|
||||||
state.DispatchParallelism(A_view, B, C_rows);
|
MMImpl::NotifyAutotuneResult(env, M, K, N, t0, tuner, cfg);
|
||||||
MMImpl::NotifyAutotuneResult(M, K, N, t0, cfg, env, tuner);
|
|
||||||
|
|
||||||
return &per_key;
|
return &per_key;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
180
ops/matmul.h
180
ops/matmul.h
|
|
@ -21,7 +21,7 @@
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
|
|
||||||
#include <memory> // std::unique_ptr
|
#include <functional>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
// IWYU pragma: begin_exports
|
// IWYU pragma: begin_exports
|
||||||
|
|
@ -54,13 +54,58 @@ HWY_INLINE_VAR constexpr size_t kNR = 4;
|
||||||
// or less on ISAs with fewer registers, or for the last few rows of A.
|
// or less on ISAs with fewer registers, or for the last few rows of A.
|
||||||
HWY_INLINE_VAR constexpr size_t kMaxMR = 4;
|
HWY_INLINE_VAR constexpr size_t kMaxMR = 4;
|
||||||
|
|
||||||
|
HWY_INLINE_VAR constexpr size_t kMaxNC = 16384; // TODO: shrink?
|
||||||
|
|
||||||
// Upper bound for per-worker B storage on the stack. Chosen such that one row
|
// Upper bound for per-worker B storage on the stack. Chosen such that one row
|
||||||
// of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`.
|
// of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`.
|
||||||
HWY_INLINE_VAR constexpr size_t kMaxKC = 8 * 1024;
|
HWY_INLINE_VAR constexpr size_t kMaxKC = 8 * 1024;
|
||||||
|
|
||||||
|
// Lightweight view into `MatStorageT`, with a fixed pitch/stride between rows.
|
||||||
|
// Also used to decompress B, hence non-const.
|
||||||
|
#pragma pack(push, 1) // power of two size
|
||||||
|
template <typename T>
|
||||||
|
class StridedView {
|
||||||
|
public:
|
||||||
|
StridedView(T* HWY_RESTRICT row0, size_t cols, size_t stride)
|
||||||
|
: row0_(row0),
|
||||||
|
cols_(static_cast<uint32_t>(cols)),
|
||||||
|
stride_(static_cast<uint32_t>(stride)) {
|
||||||
|
HWY_DASSERT(stride >= cols);
|
||||||
|
}
|
||||||
|
|
||||||
|
T* HWY_RESTRICT Row(size_t r) const { return row0_ + stride_ * r; }
|
||||||
|
size_t Cols() const { return static_cast<size_t>(cols_); }
|
||||||
|
|
||||||
|
size_t Stride() const { return static_cast<size_t>(stride_); }
|
||||||
|
void SetStride(size_t stride) {
|
||||||
|
HWY_DASSERT(stride >= Cols());
|
||||||
|
stride_ = stride;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns 2D subrange whose top-left is `r, c` and width is `cols`.
|
||||||
|
StridedView<T> View(size_t r, size_t c, size_t cols) const {
|
||||||
|
HWY_DASSERT(c < Cols());
|
||||||
|
HWY_DASSERT(cols <= Cols() - c);
|
||||||
|
return StridedView<T>(Row(r) + c, cols, stride_);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
T* HWY_RESTRICT row0_;
|
||||||
|
uint32_t cols_;
|
||||||
|
uint32_t stride_;
|
||||||
|
};
|
||||||
|
#pragma pack(pop)
|
||||||
|
|
||||||
|
using StridedViewBF = StridedView<BF16>;
|
||||||
|
using StridedViewD = StridedView<double>;
|
||||||
|
|
||||||
|
using MMFused = std::function<void(StridedViewBF, size_t, size_t)>;
|
||||||
|
|
||||||
struct MMOptions {
|
struct MMOptions {
|
||||||
uint32_t cluster_idx = 0; // for `parallelism == kWithinCluster`.
|
uint32_t cluster_idx = 0; // for `parallelism == kWithinCluster`.
|
||||||
ParallelismStrategy parallelism = ParallelismStrategy::kHierarchical;
|
ParallelismStrategy parallelism = ParallelismStrategy::kHierarchical;
|
||||||
|
|
||||||
|
MMFused fused;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Policy classes for parallelism, implementing some of `ParallelismStrategy`.
|
// Policy classes for parallelism, implementing some of `ParallelismStrategy`.
|
||||||
|
|
@ -260,49 +305,26 @@ struct MMParallelHierarchical {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <class Func, typename... Args>
|
||||||
|
void DispatchParallelism(ParallelismStrategy parallelism, const Func& func,
|
||||||
|
Args&&... args) {
|
||||||
|
switch (parallelism) {
|
||||||
|
case ParallelismStrategy::kNone:
|
||||||
|
return func(MMParallelNone(), std::forward<Args>(args)...);
|
||||||
|
case ParallelismStrategy::kWithinCluster:
|
||||||
|
return func(MMParallelWithinCluster(), std::forward<Args>(args)...);
|
||||||
|
case ParallelismStrategy::kHierarchical:
|
||||||
|
return func(MMParallelHierarchical(), std::forward<Args>(args)...);
|
||||||
|
default:
|
||||||
|
HWY_UNREACHABLE;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC);
|
void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC);
|
||||||
// C is BF16/float.
|
// C is BF16/float.
|
||||||
void BindC(ThreadingContext& ctx, MatPtr& C);
|
void BindC(ThreadingContext& ctx, MatPtr& C);
|
||||||
|
|
||||||
// Lightweight view into `MatStorageT`, with a fixed pitch/stride between rows.
|
// For A.
|
||||||
// Also used to decompress B, hence non-const.
|
|
||||||
#pragma pack(push, 1) // power of two size
|
|
||||||
template <typename T>
|
|
||||||
class StridedView {
|
|
||||||
public:
|
|
||||||
StridedView(T* HWY_RESTRICT row0, size_t cols, size_t stride)
|
|
||||||
: row0_(row0),
|
|
||||||
cols_(static_cast<uint32_t>(cols)),
|
|
||||||
stride_(static_cast<uint32_t>(stride)) {
|
|
||||||
HWY_DASSERT(stride >= cols);
|
|
||||||
}
|
|
||||||
|
|
||||||
T* HWY_RESTRICT Row(size_t r) const { return row0_ + stride_ * r; }
|
|
||||||
size_t Cols() const { return static_cast<size_t>(cols_); }
|
|
||||||
|
|
||||||
size_t Stride() const { return static_cast<size_t>(stride_); }
|
|
||||||
void SetStride(size_t stride) {
|
|
||||||
HWY_DASSERT(stride >= Cols());
|
|
||||||
stride_ = stride;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns 2D subrange whose top-left is `r, c` and width is `cols`.
|
|
||||||
StridedView<T> View(size_t r, size_t c, size_t cols) const {
|
|
||||||
HWY_DASSERT(c < Cols());
|
|
||||||
HWY_DASSERT(cols <= Cols() - c);
|
|
||||||
return StridedView<T>(Row(r) + c, cols, stride_);
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
T* HWY_RESTRICT row0_;
|
|
||||||
uint32_t cols_;
|
|
||||||
uint32_t stride_;
|
|
||||||
};
|
|
||||||
#pragma pack(pop)
|
|
||||||
|
|
||||||
using StridedViewBF = StridedView<BF16>;
|
|
||||||
using StridedViewD = StridedView<double>;
|
|
||||||
|
|
||||||
class MMStorage {
|
class MMStorage {
|
||||||
public:
|
public:
|
||||||
// Compile-time bounds on matrix columns to enable pre-allocating storage
|
// Compile-time bounds on matrix columns to enable pre-allocating storage
|
||||||
|
|
@ -354,6 +376,28 @@ enum class MMOrder : uint8_t {
|
||||||
// no kM* because we expect M (batch size) to be small relative to K and N.
|
// no kM* because we expect M (batch size) to be small relative to K and N.
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Tag types for `DispatchOrder`.
|
||||||
|
struct MMOrderNT_K {};
|
||||||
|
struct MMOrderNT {};
|
||||||
|
struct MMOrderNT_MT_K {};
|
||||||
|
struct MMOrderNT_MT {};
|
||||||
|
|
||||||
|
template <class Func, typename... Args>
|
||||||
|
void DispatchOrder(MMOrder order, const Func& func, Args&&... args) {
|
||||||
|
switch (order) {
|
||||||
|
case MMOrder::kNT_K:
|
||||||
|
return func(MMOrderNT_K(), std::forward<Args>(args)...);
|
||||||
|
case MMOrder::kNT:
|
||||||
|
return func(MMOrderNT(), std::forward<Args>(args)...);
|
||||||
|
case MMOrder::kNT_MT_K:
|
||||||
|
return func(MMOrderNT_MT_K(), std::forward<Args>(args)...);
|
||||||
|
case MMOrder::kNT_MT:
|
||||||
|
return func(MMOrderNT_MT(), std::forward<Args>(args)...);
|
||||||
|
default:
|
||||||
|
HWY_UNREACHABLE;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static inline bool IsBlock(MMOrder order) {
|
static inline bool IsBlock(MMOrder order) {
|
||||||
return order == MMOrder::kNT_MT_K || order == MMOrder::kNT_MT;
|
return order == MMOrder::kNT_MT_K || order == MMOrder::kNT_MT;
|
||||||
}
|
}
|
||||||
|
|
@ -693,26 +737,46 @@ struct MatMulEnv {
|
||||||
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
|
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Arguments to MatMul() that are independent of the A/B/C types.
|
// Arguments to MatMul() that are independent of the A/B/C types. Reduces
|
||||||
// Reduces register pressure compared to individual values/references.
|
// register pressure compared to individual values/references. Also used for
|
||||||
|
// passing through `DispatchOrder`.
|
||||||
struct MMArgs {
|
struct MMArgs {
|
||||||
MMArgs(MatMulEnv& env, MMPerKey& per_key, double scale,
|
MMArgs(MatMulEnv& env, size_t M, size_t K, size_t N, double scale,
|
||||||
const float* HWY_RESTRICT add, MMOptions options)
|
const float* HWY_RESTRICT add, MMOptions options,
|
||||||
: env(&env),
|
const MMAutoTune<MMConfig>& autotune, const MMConfig& config)
|
||||||
per_key(&per_key),
|
: env(env),
|
||||||
|
line_bytes(env.ctx.cache_info.LineBytes()),
|
||||||
|
|
||||||
|
range_n(0, N),
|
||||||
scale(scale),
|
scale(scale),
|
||||||
add(add),
|
add(add),
|
||||||
options(options),
|
options(options),
|
||||||
line_bytes(env.ctx.cache_info.LineBytes()) {}
|
|
||||||
|
|
||||||
MatMulEnv* env;
|
autotune(autotune),
|
||||||
MMPerKey* per_key;
|
mr(config.MR()),
|
||||||
|
ranges_mc(config.RangesOfMC(M)),
|
||||||
|
ranges_kc(config.RangesOfKC(K)),
|
||||||
|
ranges_nc(config.RangesOfNC(N)),
|
||||||
|
order(config.Order()),
|
||||||
|
inner_tasks(config.InnerTasks()) {}
|
||||||
|
|
||||||
double scale;
|
MatMulEnv& env;
|
||||||
|
const size_t line_bytes; // from `env`, for `Stride`.
|
||||||
|
|
||||||
|
// MatMul arguments:
|
||||||
|
const IndexRange range_n; // entire N
|
||||||
|
const double scale;
|
||||||
const float* HWY_RESTRICT add;
|
const float* HWY_RESTRICT add;
|
||||||
|
const MMOptions options;
|
||||||
|
|
||||||
MMOptions options;
|
const MMAutoTune<MMConfig>& autotune; // for `MaybeEnter`
|
||||||
size_t line_bytes;
|
// From `MMConfig`:
|
||||||
|
const size_t mr;
|
||||||
|
const IndexRangePartition ranges_mc;
|
||||||
|
const IndexRangePartition ranges_kc;
|
||||||
|
const IndexRangePartition ranges_nc;
|
||||||
|
const MMOrder order;
|
||||||
|
const size_t inner_tasks;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Wrapper over hwy::Zone that is only enabled when autotuning finished.
|
// Wrapper over hwy::Zone that is only enabled when autotuning finished.
|
||||||
|
|
@ -729,11 +793,12 @@ class MMZone {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// `name` must be a string literal.
|
template <class AutoTune>
|
||||||
void MaybeEnter(size_t thread, hwy::profiler::ZoneHandle zone,
|
void MaybeEnter(size_t thread, hwy::profiler::ZoneHandle zone,
|
||||||
const MMArgs& args) {
|
const MatMulEnv& env, const AutoTune* auto_tune) {
|
||||||
if (args.per_key->WantProfile()) {
|
// Only if enabled and autotuning finished.
|
||||||
new (&data_) Zone(args.env->ctx.profiler, thread, zone);
|
if (PROFILER_ENABLED && auto_tune->Best()) {
|
||||||
|
new (&data_) Zone(env.ctx.profiler, thread, zone);
|
||||||
HWY_DASSERT(data_ != 0);
|
HWY_DASSERT(data_ != 0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -744,7 +809,8 @@ class MMZone {
|
||||||
};
|
};
|
||||||
#else
|
#else
|
||||||
struct MMZone {
|
struct MMZone {
|
||||||
void MaybeEnter(size_t, hwy::profiler::ZoneHandle, const MMArgs&) {}
|
void MaybeEnter(size_t, hwy::profiler::ZoneHandle, const MatMulEnv&,
|
||||||
|
const void*) {}
|
||||||
};
|
};
|
||||||
#endif // PROFILER_ENABLED
|
#endif // PROFILER_ENABLED
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -37,7 +37,6 @@
|
||||||
|
|
||||||
#include "compression/compress-inl.h"
|
#include "compression/compress-inl.h"
|
||||||
#include "ops/dot-inl.h"
|
#include "ops/dot-inl.h"
|
||||||
#include "ops/matmul.h"
|
|
||||||
#include "util/mat.h" // MatPtrT
|
#include "util/mat.h" // MatPtrT
|
||||||
#include "hwy/contrib/math/math-inl.h"
|
#include "hwy/contrib/math/math-inl.h"
|
||||||
#include "hwy/contrib/matvec/matvec-inl.h"
|
#include "hwy/contrib/matvec/matvec-inl.h"
|
||||||
|
|
|
||||||
|
|
@ -40,9 +40,10 @@ class RowPtrs {
|
||||||
public:
|
public:
|
||||||
RowPtrs(uint8_t** row_ptrs) : row_ptrs_(row_ptrs) {}
|
RowPtrs(uint8_t** row_ptrs) : row_ptrs_(row_ptrs) {}
|
||||||
|
|
||||||
T* HWY_RESTRICT operator[](size_t row_idx) const {
|
T* HWY_RESTRICT Row(size_t row_idx) const {
|
||||||
return HWY_RCAST_ALIGNED(T*, row_ptrs_[row_idx]);
|
return HWY_RCAST_ALIGNED(T*, row_ptrs_[row_idx]);
|
||||||
}
|
}
|
||||||
|
T* HWY_RESTRICT operator[](size_t row_idx) const { return Row(row_idx); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
uint8_t** row_ptrs_;
|
uint8_t** row_ptrs_;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue