mirror of https://github.com/google/gemma.cpp.git

commit f9d93e4a42 (parent d854471ae2)

    Matmul rewrite: fp64 sums, hierarchical parallelization, cache-blocking, autotuning.
    Remove empty matmul_unit_test. Up to 25 TFLOP/s on 2xZen4 for 512,3072,24576.
    PiperOrigin-RevId: 729123576

BUILD.bazel | 29 lines changed
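Before the per-file diffs, a note on the throughput figure quoted above: MatMul performs 2*M*K*N floating-point operations (each FMA counts as two), so the 512,3072,24576 case works out to roughly 77 GFLOP per call, or about 3 ms per MatMul at 25 TFLOP/s. A minimal sketch of that arithmetic (illustrative only, not part of the commit):

#include <cstdio>

int main() {
  // Dimensions quoted in the commit message: M=512 (batch), K=3072, N=24576.
  const double M = 512, K = 3072, N = 24576;
  const double flop = 2.0 * M * K * N;            // FMA counts as two FLOP.
  const double tflops = 25.0;                     // reported throughput
  const double seconds = flop / (tflops * 1e12);
  std::printf("%.1f GFLOP per MatMul, ~%.2f ms at %.0f TFLOP/s\n",
              flop * 1e-9, seconds * 1e3, tflops);
  return 0;
}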
@@ -85,6 +85,9 @@ test_suite(
 cc_library(
     name = "ops",
+    srcs = [
+        "ops/matmul.cc",
+    ],
     hdrs = [
         "ops/matmul.h",
         "ops/ops.h",
@@ -103,11 +106,14 @@ cc_library(
         ":threading",
         "//compression:compress",
         "@highway//:algo",
+        "@highway//:bit_set",
         "@highway//:hwy",
        "@highway//:math",
         "@highway//:matvec",
+        "@highway//:nanobenchmark",
         "@highway//:profiler",
         "@highway//:thread_pool",
+        "@highway//:topology",
         "@highway//hwy/contrib/sort:vqsort",
     ],
 )
@@ -126,6 +132,7 @@ cc_test(
         ":test_util",
         ":threading",
         "@googletest//:gtest_main",  # buildcleaner: keep
+        "//:app",
         "//compression:compress",
         "//compression:test_util",
         "@highway//:hwy",
@@ -151,6 +158,7 @@ cc_test(
         ":ops",
         ":test_util",
         "@googletest//:gtest_main",  # buildcleaner: keep
+        "//:app",
         "//compression:compress",
         "@highway//:hwy",
         "@highway//:hwy_test_util",
@@ -176,26 +184,6 @@ cc_test(
     ],
 )
 
-cc_test(
-    name = "matmul_unit_test",
-    size = "small",
-    timeout = "long",
-    srcs = ["ops/matmul_unit_test.cc"],
-    local_defines = ["HWY_IS_TEST"],
-    # for test_suite.
-    tags = ["ops_tests"],
-    deps = [
-        ":allocator",
-        ":basics",
-        ":ops",
-        ":test_util",
-        "@googletest//:gtest_main",  # buildcleaner: keep
-        "//compression:compress",
-        "@highway//:hwy",
-        "@highway//:hwy_test_util",
-    ],
-)
-
 cc_test(
     name = "matmul_test",
     size = "small",
@@ -652,6 +640,7 @@ cc_test(
         ":sampler",
         ":weights",
         "@googletest//:gtest_main",
+        "//:threading",
         "//compression:compress",
         "@highway//:hwy",
         "@highway//:hwy_test_util",
@@ -92,6 +92,8 @@ set(SOURCES
   gemma/weights.h
   ops/dot-inl.h
   ops/matmul-inl.h
+  ops/matmul.cc
+  ops/matmul.h
   ops/matvec-inl.h
   ops/ops-inl.h
   ops/ops.h
@@ -168,7 +170,6 @@ set(GEMMA_TEST_FILES
   ops/dot_test.cc
   ops/gemma_matvec_test.cc
   ops/matmul_test.cc
-  ops/matmul_unit_test.cc
   ops/ops_test.cc
   paligemma/image_test.cc
   paligemma/paligemma_test.cc
@@ -33,6 +33,7 @@
 #include "backprop/test_util.h"
 #include "gemma/configs.h"
 #include "ops/ops.h"
+#include "util/threading.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
@@ -58,7 +59,9 @@ void TestMatMulVJP() {
   static const size_t kRows = 8;
   static const size_t kCols = 64;
   static const size_t kTokens = 5;
-  hwy::ThreadPool pool(8);
+  gcpp::NestedPools pools(1, /*pin=*/Tristate::kFalse, BoundedSlice(0, 1),
+                          BoundedSlice(0, 8));
+  Allocator::Init(pools.Topology());
   std::mt19937 gen(42);
   MatStorageT<float> weights("weights", kRows, kCols);
   MatStorageT<float> x("x", kTokens, kCols);
@@ -85,7 +88,7 @@ void TestMatMulVJP() {
 
   grad.ZeroInit();
   MatMulVJP(weights.data(), x.data(), dy.data(), kCols, kRows, kTokens,
-            grad.data(), dx.data(), pool);
+            grad.data(), dx.data(), pools.Pool());
   TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__);
   TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__);
@@ -102,7 +105,9 @@ void TestMultiHeadMatMulVJP() {
   static const size_t kCols = 16;
   static const size_t kHeads = 4;
   static const size_t kTokens = 3;
-  hwy::ThreadPool pool(8);
+  gcpp::NestedPools pools(1, /*pin=*/Tristate::kFalse, BoundedSlice(0, 1),
+                          BoundedSlice(0, 8));
+  Allocator::Init(pools.Topology());
   std::mt19937 gen(42);
   MatStorageT<float> weights("weights", kRows, kCols * kHeads);
   MatStorageT<float> x("x", kTokens, kCols * kHeads);
@@ -130,7 +135,7 @@ void TestMultiHeadMatMulVJP() {
 
   grad.ZeroInit();
   MultiHeadMatMulVJP(weights.data(), x.data(), dy.data(), kHeads, kCols,
-                     kRows, kTokens, grad.data(), dx.data(), pool);
+                     kRows, kTokens, grad.data(), dx.data(), pools.Pool());
   TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__);
   TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__);
@@ -145,7 +150,9 @@ void TestMultiHeadMatMulVJP() {
 void TestRMSNormVJP() {
   static const size_t K = 2;
   static const size_t N = 64;
-  hwy::ThreadPool pool(8);
+  gcpp::NestedPools pools(1, /*pin=*/Tristate::kFalse, BoundedSlice(0, 1),
+                          BoundedSlice(0, 8));
+  Allocator::Init(pools.Topology());
   std::mt19937 gen(42);
   MatStorageT<float> weights("weights", N, 1);
   MatStorageT<float> x("x", K, N);
@@ -172,7 +179,7 @@ void TestRMSNormVJP() {
 
   grad.ZeroInit();
   RMSNormVJP(weights.data(), x.data(), dy.data(), N, K, grad.data(),
-             dx.data(), pool);
+             dx.data(), pools.Pool());
   TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__);
   TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__);
@@ -209,7 +216,9 @@ static ModelConfig TestConfig() {
 
 void TestEndToEnd() {
   std::mt19937 gen(42);
-  hwy::ThreadPool pool(0);
+  gcpp::NestedPools pools(1, /*pin=*/Tristate::kFalse, BoundedSlice(0, 1),
+                          BoundedSlice(0, 1));
+  Allocator::Init(pools.Topology());
   ModelConfig config = TestConfig();
   WeightsWrapper<float> weights(config);
   WeightsWrapper<float> grad(config);
@@ -234,13 +243,13 @@ void TestEndToEnd() {
 
   float loss1 = CrossEntropyLossForwardPass(
       prompt.tokens, prompt.context_size, weights.get(), forward1,
-      inv_timescale, pool);
+      inv_timescale, pools.Pool());
 
   EXPECT_NEAR(loss1, loss0, std::abs(loss0) * 2e-5);
 
   grad.ZeroInit();
   CrossEntropyLossBackwardPassInl(prompt, weights.get(), forward1, grad.get(),
-                                  backward, inv_timescale, pool);
+                                  backward, inv_timescale, pools.Pool());
 
   Complexify(weights.get(), c_weights.get());
   auto func = [&]() {
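The test changes above all follow one pattern: the ad-hoc hwy::ThreadPool is replaced by a NestedPools hierarchy plus an Allocator::Init call, and call sites take pools.Pool() instead of the raw pool. A minimal sketch of that setup sequence, using the names that appear in the diff (the surrounding function and the exact namespace qualifications are assumptions for illustration):

#include "util/allocator.h"
#include "util/threading.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

void SetUpPoolsForTest() {
  // One package, no pinning, first cluster only, up to 8 workers,
  // mirroring the values used in the updated tests above.
  gcpp::NestedPools pools(1, /*pin=*/gcpp::Tristate::kFalse,
                          gcpp::BoundedSlice(0, 1), gcpp::BoundedSlice(0, 8));
  // The allocator (and hence MatMul) now needs topology information up front.
  gcpp::Allocator::Init(pools.Topology());
  // Call sites that previously took `pool` now take the package-level pool.
  hwy::ThreadPool& pool = pools.Pool();
  (void)pool;
}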
@@ -252,6 +252,10 @@ class Gemma {
   void GenerateImageTokens(const RuntimeConfig& runtime_config,
                            const Image& image, ImageTokens& image_tokens);
 
+  void SetMatMulVerbosity(int verbosity) {
+    if (verbosity >= 2) env_.print_best = true;
+  }
+
  private:
   MatMulEnv env_;
@@ -225,7 +225,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
     }
     if (end_of_turn_seen && abs_pos > 0) {
       // If we have seen an end_of_turn token, we need to rewind abs_pos by one
-      // more, because we will pre-pend it again to the prompt in
+      // more, because we will prepend it again to the prompt in
       // WrapAndTokenize.
       abs_pos--;
     }
@@ -236,14 +236,13 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
 void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
   PROFILER_ZONE("Run.misc");
 
-  // TODO: remove once MatMul is updated.
-  app.max_packages = 1;
   // Note that num_threads is an upper bound; we also limit to the number of
   // detected and enabled cores.
   NestedPools pools = CreatePools(app);
   Allocator::Init(pools.Topology());
 
   Gemma model = CreateGemma(loader, pools);
+  model.SetMatMulVerbosity(app.verbosity);
   KVCache kv_cache =
       KVCache::Create(model.GetModelConfig(), inference.prefill_tbatch_size);
@@ -117,17 +117,18 @@ MatStoragePtr<MatT> GenerateTransposedMat(const Extents2D extents,
 }
 
 void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents,
-                std::vector<double>& times) {
+                std::vector<double>& times, MMPerKey* per_key) {
   std::sort(times.begin(), times.end());
   // bench_dnn reports the best and average, but the median seems more
   // consistent and resistant to outliers.
   const double elapsed = times[times.size() / 2];
-  const double ratio = elapsed / (times[0] + 1E-6);  // vs best, avoid / 0
+  const double vs_best = elapsed / (times[0] + 1E-6);  // avoid / 0
 
   const size_t num_b = B_extents.Area();
-  // FMA counts as two FLOP.
-  fprintf(stderr, "%.1f\t(med %.3f ms = %0.2fx min)\n",
-          2 * 1E-9 * A_extents.rows * num_b / elapsed, elapsed * 1E3, ratio);
+  const double flops = 2 * A_extents.rows * num_b / elapsed;  // FMA = 2 ops
+  fprintf(stderr, "\t%.1f GFLOPS %.3f ms %0.2fx\n", flops * 1E-9, elapsed * 1E3,
+          vs_best);
 }
 
 // Generates inputs and prints observed throughput of MatMul.
@@ -135,15 +136,18 @@ void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents,
 template <typename MatTA, typename MatTB = MatTA>
 void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
   hwy::ThreadPool& pool = env.parallel.Pools().Pool(0);
-  fprintf(stderr, "\nBenchMatMul %lu, %lu, %lu, add=%d, MatTA=%s, MatTB=%s\n",
-          M, K, N, add, TypeName<MatTA>(), TypeName<MatTB>());
+  if (env.print_config || env.print_measurement) {
+    fprintf(stderr, "\n");
+  }
+  fprintf(stderr, "BenchMatMul %zu, %zu, %zu, add=%d, TA=%s, TB=%s\n", M, K, N,
+          add, TypeName<MatTA>(), TypeName<MatTB>());
 
   const Extents2D A_extents(M, K);
   const Extents2D B_extents(N, K);  // already transposed
   const Extents2D C_extents(M, N);
 
-  RowVectorBatch<float> c_slow_batch(C_extents);
-  RowVectorBatch<float> c_batch(C_extents);
+  RowVectorBatch<float> c_slow_batch = AllocateAlignedRows<float>(C_extents);
+  RowVectorBatch<float> c_batch = AllocateAlignedRows<float>(C_extents);
 
   std::unique_ptr<MatStorageT<float>> add_storage;
   if (add) {
@@ -161,27 +165,40 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
   const float* add_row = add ? add_storage->data_scale1() : nullptr;
   const RowPtrF C = RowPtrFromBatch(c_batch);
 
-  constexpr size_t kSamples = 20;
+  // Fewer reps for large batch sizes, which take longer.
+  const size_t num_samples = M < 32 ? 20 : 12;
   std::vector<double> times;
-  times.reserve(kSamples);
+  times.reserve(num_samples);
 
+  // Ensure usage conditions are set before autotuning. Both binding and
+  // spinning may materially affect the choice of config. No harm in calling
+  // BindB/C if there is a single package: they will be a no-op.
+  BindB(B_extents.rows, B, env.parallel);
+  BindC(A_extents.rows, C, env.parallel);
+
   Tristate use_spinning = Tristate::kDefault;
   env.parallel.Pools().MaybeStartSpinning(use_spinning);
 
+  // env.print_config = true;
+  // env.print_measurement = true;
+  env.print_best = true;
+
   double keep = 0.0;
+  MMPerKey* per_key;
   // Until enough samples collected *after* autotuning finished:
-  while (times.size() < kSamples) {
+  while (times.size() < num_samples) {
     const double t0 = hwy::platform::Now();
-    MatMul(A, B, add_row, env, C);
+    per_key = MatMul(A, B, add_row, env, C);
     const double t1 = hwy::platform::Now();
     double elapsed = t1 - t0;
     keep += C.Row(0)[hwy::Unpredictable1()];
 
-    times.push_back(elapsed);
+    // Only record times after autotuning finished.
+    if (per_key->autotune.Best()) times.push_back(elapsed);
   }
   hwy::PreventElision(keep);
   env.parallel.Pools().MaybeStopSpinning(use_spinning);
-  PrintSpeed(A_extents, B_extents, times);
+  PrintSpeed(A_extents, B_extents, times, per_key);
 }
 
 using F32 = float;
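The benchmark change above establishes the measurement discipline the autotuner requires: keep calling MatMul, but only record wall-clock samples once the per-shape autotuner has settled on its best config. Distilled into a standalone helper (a sketch; MatMul, MMPerKey and autotune.Best() are names from the diff, the wrapper itself is hypothetical):

#include <stddef.h>
#include <vector>
#include "ops/matmul.h"
#include "hwy/timer.h"

// Sketch: time a MatMul-like callable, discarding iterations taken while the
// autotuner is still exploring candidate configs.
template <class MatMulCall>
std::vector<double> TimeAfterAutotune(const MatMulCall& call,
                                      size_t num_samples) {
  std::vector<double> times;
  times.reserve(num_samples);
  while (times.size() < num_samples) {
    const double t0 = hwy::platform::Now();
    gcpp::MMPerKey* per_key = call();  // returns the per-shape autotune state
    const double t1 = hwy::platform::Now();
    // Only keep the measurement once the best config has been chosen.
    if (per_key->autotune.Best()) times.push_back(t1 - t0);
  }
  return times;
}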
@@ -189,29 +206,31 @@ using SFP = SfpStream;
 
 void BenchAllMatMul() {
   if (first_target == 0) first_target = HWY_TARGET;
-  if (HWY_TARGET != first_target) return;
+  // Disable the best-target-only limitation.
+  // if (HWY_TARGET != first_target) return;
 
-  for (size_t max_packages : {/*1,*/ 2}) {
-    const size_t max_threads = 0;  // no limit
-    NestedPools pools(max_threads, Tristate::kDefault,
-                      BoundedSlice(0, max_packages));
-#if GEMMA_DISABLE_TOPOLOGY
-    if (max_packages == 2) break;  // we only have one package
-#else
-    // If less than the limit, we have already tested all num_packages.
-    if (pools.Topology().FullTopology().packages.size() < max_packages) break;
-#endif
-    fprintf(stderr, "BenchAllMatMul %zu: %s %s\n", max_packages,
-            pools.TopologyString(), pools.PinString());
-
-    Allocator::Init(pools.Topology());
-    MatMulEnv env(pools);
-
-    for (size_t batch_size : {1, 4, 128, 512}) {
-      constexpr bool kAdd = false;
-      BenchMatMul<BF16, SFP>(batch_size, 24576, 3072, kAdd, env);
-      BenchMatMul<BF16, SFP>(batch_size, 3072, 24576, kAdd, env);
-    }
+  // Skip EMU128 (10x slower than SSE4 for SFP) and older x86.
+  if (HWY_TARGET == HWY_EMU128 || HWY_TARGET == HWY_SSSE3 ||
+      HWY_TARGET == HWY_SSE2) {
+    return;
+  }
+
+  const size_t max_threads = 0;  // no limit
+  const BoundedSlice package_slice;  // all packages/sockets
+  const BoundedSlice cluster_slice;  // all clusters/CCX
+  const BoundedSlice lp_slice;  // default to all cores (per package).
+  NestedPools pools(max_threads, Tristate::kDefault, package_slice,
+                    cluster_slice, lp_slice);
+  fprintf(stderr, "BenchAllMatMul %s %s\n", pools.TopologyString(),
+          pools.PinString());
+
+  Allocator::Init(pools.Topology(), /*enable_bind=*/true);
+  MatMulEnv env(pools);
+
+  for (size_t batch_size : {1, 4, 128, 512}) {
+    constexpr bool kAdd = false;
+    BenchMatMul<BF16, SFP>(batch_size, 24576, 3072, kAdd, env);
+    BenchMatMul<BF16, SFP>(batch_size, 3072, 24576, kAdd, env);
   }
 
   PROFILER_PRINT_RESULTS();
@@ -28,6 +28,7 @@
 
 #include "compression/shared.h"
 #include "util/allocator.h"
+#include "util/app.h"
 #include "util/test_util.h"
 #include "util/threading.h"
 #include "hwy/base.h"
@@ -999,6 +1000,8 @@ struct TestShortDotsT {
     const size_t N = hn::Lanes(d);
     const hn::ScalableTag<float> df;  // for CallDot
 
+    NestedPools pools = CreatePools(AppArgs());
+    Allocator::Init(pools.Topology());
     CompressWorkingSet work;
     std::mt19937 rng;
     rng.seed(12345);
ops/matmul-inl.h | 1699 lines changed (file diff suppressed because it is too large)
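The new implementation file added below (ops/matmul.cc, per the build changes above) generates autotuning candidates from an analytical cache model. Its two small helpers, RoundDownWithFloor and PrevDivisor, perform the divisor search behind the kc/mc/nc proposals. A standalone, runnable sketch of their behavior, restated here purely for illustration (the authoritative definitions are in the diff below; the numeric bound of 512 is a hypothetical L1-derived upper limit):

#include <cstddef>
#include <cstdio>

// Rounds down to a multiple of `multiple`, but returns at least `multiple`.
static size_t RoundDownWithFloor(size_t value, size_t multiple) {
  const size_t rounded = (value / multiple) * multiple;
  return rounded < multiple ? multiple : rounded;
}

// Highest value in [begin, end) that divides `dim` and is a multiple of
// `multiple`, or 0 if none exists (same contract as in the diff below).
static size_t PrevDivisor(size_t begin, size_t end, size_t dim,
                          size_t multiple) {
  size_t prev = RoundDownWithFloor(end, multiple);
  if (prev == end) prev -= multiple;  // avoid returning `end` itself
  for (;;) {
    if (prev == 0) return 0;
    if (dim % prev == 0) return prev;
    if (prev <= begin) return 0;
    prev -= multiple;
  }
}

int main() {
  // Example: a kc candidate for K=3072 with kc_multiple=64, searching below an
  // assumed upper bound of 512: 448 does not divide 3072, but 384 does.
  std::printf("%zu\n", PrevDivisor(/*begin=*/64, /*end=*/512, /*dim=*/3072,
                                   /*multiple=*/64));  // prints 384
  return 0;
}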
@@ -0,0 +1,415 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ops/matmul.h"
+
+// Analytical model of cache parameters for generating autotune candidates.
+
+#include <stddef.h>
+#include <stdint.h>
+#include <stdio.h>
+
+#include <vector>
+
+#include "util/allocator.h"
+#include "util/basics.h"
+#include "util/threading.h"
+#include "hwy/base.h"
+#include "hwy/detect_targets.h"
+#include "hwy/per_target.h"
+#include "hwy/timer.h"
+
+namespace gcpp {
+namespace {
+
+// Rounds down to a multiple of `multiple`, but returns at least `multiple`.
+size_t RoundDownWithFloor(size_t value, size_t multiple) {
+  HWY_DASSERT(multiple != 0);
+  return HWY_MAX(multiple, hwy::RoundDownTo(value, multiple));
+}
+
+// Returns the highest number in `[begin, end)` that divides `dim` and is a
+// multiple of `multiple`, or 0 if none exists.
+size_t PrevDivisor(const size_t begin, const size_t end, const size_t dim,
+                   const size_t multiple) {
+  HWY_DASSERT(end != 0 && dim != 0 && multiple != 0);
+  size_t prev = RoundDownWithFloor(end, multiple);
+  // Avoid returning `end` if rounding down had no effect.
+  if (prev == end) prev -= multiple;
+  for (;;) {
+    if (prev == 0) return 0;  // No divisor if large multiple or small end.
+    if (dim % prev == 0) return prev;
+    if (prev <= begin) return 0;
+    prev -= multiple;
+  }
+}
+
+// Implementation of `MMCandidates`. Class hides the `KC` etc member functions
+// and holds most of their arguments in member variables.
+class GenerateCandidates {
+ public:
+  GenerateCandidates(size_t M, size_t K, size_t N, size_t max_mr, size_t nr,
+                     const IndexRangePartition& ranges_np, bool print_config)
+      : M_(M),
+        K_(K),
+        max_mr_(max_mr),
+        nr_(nr),
+        // These influence kc/nc, but are also stored in `MMConfig` for
+        // `RangesOf*`. Must be a vector multiple. The previous/next cache line
+        // is likely still in L1, but we expect K > 1000 and might as well round
+        // up to the line size.
+        kc_multiple_(HWY_MIN(K, Allocator::LineBytes() / sizeof(BF16))),
+        nc_multiple_(Allocator::StepBytes() / sizeof(float)),
+        ranges_np_(ranges_np),
+        print_config_(print_config) {}
+
+  std::vector<MMConfig> operator()() const {
+    std::vector<MMConfig> candidates;
+    candidates.reserve(128);
+
+    for (size_t mr : MR()) {
+      for (MMOrder order : Orders(mr)) {
+        const std::vector<int>& all_inner_tasks = InnerTasks(order);
+        const std::vector<MMOut>& all_outs = Outs(order);
+        for (size_t kc : KC(mr, order)) {
+          for (size_t mc : MC(mr, kc, order)) {
+            for (size_t nc : NC(mr, mc, kc, order)) {
+              for (int inner_tasks : all_inner_tasks) {
+                for (MMOut out : all_outs) {
+                  const MMConfig config(K_, mr, mc, kc, nc, kc_multiple_,
+                                        nc_multiple_, order, out, inner_tasks);
+                  const size_t M_tasks = config.RangesOfMC(M_).NumTasks();
+                  const size_t K_tasks = config.RangesOfKC(K_).NumTasks();
+
+                  // Blocks only make sense when there are multiple M tasks.
+                  if (IsBlock(order) != (M_tasks > 1)) continue;
+                  // Single KC only makes sense when there is a single K task.
+                  if (IsOneKC(order) != (K_tasks == 1)) continue;
+
+                  candidates.push_back(config);
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+
+    HWY_ASSERT(!candidates.empty());
+    return candidates;
+  }
+
+ private:
+  using SizeVec = std::vector<size_t>;
+
+  // How many rows of A per call to `MMKernel::LoopOverKC`. Lower values may
+  // be better for SIMD targets with fewer registers.
+  SizeVec MR() const {
+    const int64_t target = hwy::DispatchedTarget();
+    const bool is_avx2 = target == HWY_AVX2;
+    const bool is_sse = HWY_SSE4 <= target && target <= HWY_SSE2;
+    const bool is_wasm = target == HWY_WASM || target == HWY_WASM_EMU256;
+
+    SizeVec all_mr;
+    all_mr.reserve(3);
+    // AVX2's 16 registers are not enough for four rows, but SSE4 may benefit.
+    if (M_ >= max_mr_ && !is_avx2) all_mr.push_back(max_mr_);
+    // Allow for AVX-512 but not SSE4 (for which 4 are usually better). Also
+    // enable if not enough rows for 4.
+    if (M_ >= 2 && (M_ < max_mr_ || (!is_sse && !is_wasm))) {
+      all_mr.push_back(size_t{2});
+    }
+    // Even SSE4 usually prefers 2 rows; only enable for single rows.
+    if (M_ == 1) all_mr.push_back(size_t{1});
+    HWY_ASSERT(!all_mr.empty());
+    return all_mr;
+  }
+
+  // Which loop orders to enable depending on M.
+  std::vector<MMOrder> Orders(size_t mr) const {
+    std::vector<MMOrder> orders;
+    for (size_t order_idx = 0;; ++order_idx) {
+      const MMOrder order = static_cast<MMOrder>(order_idx);
+      if (StringFromOrder(order) == nullptr) return orders;  // done
+      // 2D blocking is useless for a single row of M.
+      if (IsBlock(order) && M_ <= mr) continue;
+      // Conversely, N-only parallelism is uncompetitive for large M.
+      if (!IsBlock(order) && M_ >= 8 * mr) continue;
+      orders.push_back(order);
+    }
+  }
+
+  // The number of A and B columns to read between updating `partial`.
+  SizeVec KC(size_t mr, MMOrder order) const {
+    // `LoopOverKC` handles up to `mr` rows of A.
+    const size_t rows_a = HWY_MIN(M_, mr);
+
+    // After looping over `kc` columns, we write `mr x 4` outputs and 16 vector
+    // `buf`. To amortize the write cost, we want to maximize `kc`. However, it
+    // is important that B fits in L1, because batch=1 only has a single row of
+    // A and thus no reuse of the packed B. When L1-resident, we can use the
+    // separate `DecompressAndZeroPad` to write `kc` columns, rather than having
+    // to integrate `Decompress2` into `LoopOverKC`, which is less efficient for
+    // TB=NUQ due to less amortization of the table loads. Due to the low L1
+    // latency, the packing is still effectively fused into `LoopOverKC`. It may
+    // be better to round up and accept a few L2 accesses in exchange for
+    // fewer loops over K, and thus fewer writes to `partial`. Hence we do not
+    // subtract the output and buf, and allow using more than the actual L1
+    // size. This results in an overestimate, and the loop below will propose
+    // the next few smaller values for the autotuner to evaluate.
+    const size_t bytes_ab = Allocator::L1Bytes() * 3;
+    const size_t col_bytes = rows_a * sizeof(BF16) + nr_ * sizeof(BF16);
+    size_t kc_max = hwy::DivCeil(bytes_ab, col_bytes);
+    kc_max =
+        RoundDownWithFloor(HWY_MIN(kc_max, MMStorage::kMaxKC), kc_multiple_);
+    kc_max = HWY_MIN(kc_max, K_);
+
+    SizeVec all_kc(1, kc_max);
+
+    // Avoid proposing kc > K.
+    if (K_ > kc_multiple_) {
+      // Generally it is best to use the full `kc` (fewer writes to `partial`),
+      // but a bit less can be better if it evenly divides `K`, or enables an
+      // `mc` that evenly divides `M`. Try several smaller values.
+
+      // If we can afford a single K task, that's usually best; only try one
+      // more. Otherwise, blocks may require smaller kc (more options).
+      const size_t reps = (kc_max == K_) ? 1 : IsBlock(order) ? 3 : 2;
+
+      size_t prev = kc_max;
+      for (size_t rep = 0; rep < reps; ++rep) {
+        const size_t div = PrevDivisor(kc_multiple_, prev, K_, kc_multiple_);
+        prev = div ? div : RoundDownWithFloor(prev / 2, kc_multiple_);
+        all_kc.push_back(prev);
+      }
+    }
+
+    if (print_config_ && all_kc.size() > 1) {
+      fprintf(stderr, "KC: ");
+      for (size_t kc : all_kc) {
+        fprintf(stderr, "%zu ", kc);
+      }
+      fprintf(stderr, "\n");
+    }
+
+    return all_kc;
+  }
+
+  // The number of (L2 resident) A rows for `A2C0` to loop over.
+  SizeVec MC(size_t mr, size_t kc, MMOrder order) const {
+    // Typically 12-24K. The B rows are pinned in L1, but also occupy L2 because
+    // it is typically inclusive.
+    const size_t bytes_b = nr_ * kc * (sizeof(SfpStream) + sizeof(BF16));
+
+    // Choose the largest feasible `mc_max` (A/C rows) to maximize reuse of the
+    // packed B. We want `mc * kc` elements of A to fit in L2, alongside
+    // `bytes_b` plus `mc` cache lines because resident-A updates `mc` rows of
+    // partial.
+    const size_t bytes_per_mc = kc * sizeof(BF16) + Allocator::LineBytes();
+    size_t mc_max = hwy::DivCeil(Allocator::L2Bytes() - bytes_b, bytes_per_mc);
+    mc_max = HWY_MIN(mc_max, MMStorage::kMaxM);
+    HWY_DASSERT(mc_max != 0);
+    mc_max = HWY_MIN(mc_max, M_);
+    mc_max = hwy::RoundDownTo(mc_max, mr);
+
+    SizeVec all_mc(1, mc_max);
+    // Larger MC is better for non-blocks, otherwise we want more small options.
+    const size_t reps = !IsBlock(order) ? 2 : 3;
+
+    size_t prev = mc_max;
+    for (size_t rep = 0; rep < reps; ++rep) {
+      prev = PrevDivisor(1, prev, M_, mr);
+      if (prev >= mc_max || prev == 0) break;
+      all_mc.push_back(prev);
+    }
+
+    // Blocks: largest is not useful.
+    if (IsBlock(order) && all_mc.size() > 1) {
+      all_mc.erase(all_mc.begin(), all_mc.begin() + 1);
+    }
+
+    if (print_config_ && all_mc.size() > 1) {
+      fprintf(stderr, "MC: ");
+      for (size_t mc : all_mc) {
+        fprintf(stderr, "%zu ", mc);
+      }
+      fprintf(stderr, "\n");
+    }
+
+    return all_mc;
+  }
+
+  // The number of (possibly L3 resident) B rows per `NT_MT` task.
+  SizeVec NC(size_t mr, size_t mc, size_t kc, MMOrder order) const {
+    const size_t np_max = ranges_np_.TaskSize();
+    size_t nc_max = np_max;
+    const size_t out_bytes = IsOneKC(order) ? sizeof(float) : sizeof(double);
+    // Only if there will be reuse of B: choose the largest `nc_max` (C cols)
+    // such that `nc x kc` of B and `mc x nc` of `partial` or `C` fit in L3.
+    // Otherwise, leave it unbounded.
+    if (M_ > mr) {
+      const size_t bytes_per_nc = (kc * sizeof(BF16) + mc * out_bytes);
+      nc_max = hwy::DivCeil(Allocator::L3Bytes(), bytes_per_nc);
+      nc_max = HWY_MIN(HWY_MIN(nc_max, MMStorage::kMaxN), np_max);
+    }
+    HWY_DASSERT(nc_max != 0);
+    nc_max = RoundDownWithFloor(nc_max, nc_multiple_);
+
+    // If there are going to be multiple ranges, anything more than half would
+    // be imbalanced and suboptimal.
+    if (nc_max < np_max && nc_max >= np_max / 2) {
+      nc_max = RoundDownWithFloor(np_max / 2, nc_multiple_);
+    }
+
+    // Non-block calls ForNP, which ignores `range_nc` and uses `range_np`.
+    if (!IsBlock(order)) return SizeVec(1, np_max);
+
+    SizeVec all_nc(1, nc_max);
+
+    // Avoid proposing nc > N.
+    if (np_max > nc_multiple_) {
+      // Large L3, but its behavior and characteristics varies across platforms,
+      // hence autotune a wider range of nc than the other dimensions.
+      size_t reps = 10;
+      // For small M, we can afford larger NC, hence allow fewer small options.
+      if (M_ <= 2 * mr) reps -= 1;
+
+      size_t prev = nc_max;
+      for (size_t rep = 0; rep < reps; ++rep) {
+        const size_t div =
+            PrevDivisor(nc_multiple_, prev, np_max, nc_multiple_);
+        prev = div ? div : RoundDownWithFloor(prev / 2, nc_multiple_);
+        all_nc.push_back(prev);
+        if (prev == nc_multiple_) break;
+      }
+
+      // Skip the larger values (unlikely to be chosen), keep about 40%.
+      const ptrdiff_t want_delete =
+          static_cast<ptrdiff_t>(all_nc.size() * 5 / 9 + 2);
+      // Keep at least 2.
+      const ptrdiff_t max_delete =
+          HWY_MAX(static_cast<ptrdiff_t>(all_nc.size()) - 2, ptrdiff_t{0});
+      all_nc.erase(all_nc.begin(),
+                   all_nc.begin() + HWY_MIN(want_delete, max_delete));
+    }
+
+    if (print_config_ && all_nc.size() > 1) {
+      fprintf(stderr, "NC: ");
+      for (size_t nc : all_nc) {
+        fprintf(stderr, "%zu ", nc);
+      }
+      fprintf(stderr, "\n");
+    }
+
+    return all_nc;
+  }
+
+  // How many tasks per cluster worker. More = smaller tasks, which can lead
+  // to better load balancing at the cost of higher overhead.
+  std::vector<int> InnerTasks(MMOrder order) const {
+    std::vector<int> inner_tasks;
+    inner_tasks.reserve(3);
+    inner_tasks.push_back(1);
+    // Blocks have one task per mc/nc range and ignore this parameter.
+    if (!IsBlock(order)) {
+      inner_tasks.push_back(2);
+      inner_tasks.push_back(4);
+    }
+    return inner_tasks;
+  }
+
+  // Whether to parallelize FillC or enable direct writes to C.
+  std::vector<MMOut> Outs(MMOrder order) const {
+    std::vector<MMOut> outs;
+    for (size_t out_idx = 0;; ++out_idx) {
+      const MMOut out = static_cast<MMOut>(out_idx);
+      if (StringFromOut(out) == nullptr) return outs;  // done
+      // kParM only makes sense if we have more than one row of A.
+      if (out == MMOut::kParM && M_ == 1) continue;
+      // Blocks are already parallelized.
+      if (out == MMOut::kParM && IsBlock(order)) continue;
+      // Direct only works for a single kc range.
+      if ((out == MMOut::kDirect) != IsOneKC(order)) continue;
+      // For non-block, kCopy does not beat kDirect.
+      if (out == MMOut::kCopy && IsOneKC(order) && !IsBlock(order)) continue;
+      outs.push_back(out);
+    }
+  }
+
+  const size_t M_;
+  const size_t K_;
+
+  const size_t max_mr_;
+  const size_t nr_;
+
+  const size_t kc_multiple_;
+  const size_t nc_multiple_;
+
+  IndexRangePartition ranges_np_;
+
+  const bool print_config_;
+};
+
+}  // namespace
+
+// Facade to avoid exposing `GenerateCandidates` in the header.
+std::vector<MMConfig> MMCandidates(size_t M, size_t K, size_t N, size_t max_mr,
+                                   size_t nr,
+                                   const IndexRangePartition& ranges_np,
+                                   bool print_config) {
+  return GenerateCandidates(M, K, N, max_mr, nr, ranges_np, print_config)();
+}
+
+// Returns the granularity of B rows for `RangesOfNP`. Aims to avoid remote
+// memory accesses or false sharing, unless there are insufficient per-package
+// rows for that.
+static size_t NPMultiple(size_t N, size_t nr, size_t num_packages) {
+  size_t np_multiple = Allocator::QuantumBytes() / sizeof(float);
+  // If binding, `np_multiple` is typically 1024 and `num_packages` > 1. For
+  // `N` < 4096, this can cause significant load imbalance. If split unevenly,
+  // choose a smaller multiple.
+  if (N % (np_multiple * num_packages)) {
+    const size_t min_multiple = Allocator::LineBytes() / sizeof(float);
+    np_multiple =
+        PrevDivisor(min_multiple, np_multiple, N / num_packages, min_multiple);
+    if (HWY_UNLIKELY(np_multiple == 0)) {
+      np_multiple = min_multiple;
+    }
+    // This happens in tests with small N, hence do not assert.
+    if (N % (np_multiple * num_packages) && N >= 128) {
+      HWY_WARN("NPMultiple: N=%zu still not divisible by np_multiple=%zu\n", N,
+               np_multiple);
+      np_multiple = nr;
+    }
+  }
+  return np_multiple;
+}
+
+IndexRangePartition MMParallel::RangesOfNP(size_t max_packages, size_t N,
+                                           size_t nr) const {
+  const size_t num_packages = HWY_MIN(max_packages, pools_.NumPackages());
+  return StaticPartition(IndexRange(0, N), num_packages,
+                         NPMultiple(N, nr, num_packages));
+}
+
+MatMulEnv::MatMulEnv(NestedPools& pools) : parallel(pools), storage(parallel) {
+  // Ensure Allocator:Init was called.
+  HWY_ASSERT(Allocator::LineBytes() != 0 && Allocator::VectorBytes() != 0);
+
+  char cpu100[100];
+  have_timer_stop = hwy::platform::HaveTimerStop(cpu100);
+}
+
+}  // namespace gcpp
ops/matmul.h | 710 lines changed

@@ -16,87 +16,668 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_H_
 #define THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_H_
+
+// Non-SIMD part of MatMul: parallelization, allocation, and autotuning.
+
 #include <stddef.h>
+#include <stdint.h>
+
+#include <vector>
+
 // IWYU pragma: begin_exports
 #include "compression/compress.h"
 #include "util/allocator.h"
 #include "util/basics.h"
 #include "util/threading.h"
+#include "hwy/aligned_allocator.h"  // Span
 #include "hwy/base.h"
+#include "hwy/bit_set.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
+#include "hwy/profiler.h"
 // IWYU pragma: end_exports
 
-#include "hwy/per_target.h"  // VectorBytes
-
 namespace gcpp {
 
-// TODO: remove deprecated typedef.
-using Range1D = IndexRange;
-
 // The MatMul result C[r,c] is Dot(A.Row(r), B.Col(c)). To reduce the number of
 // loads, we reuse the same A row for several B columns, which are also loaded
 // once for several rows of C. Thus we produce one 'tile' of C at a time of
-// dimensions `kRegRows` x `kRegCols`. The Reg naming is because these are
-// limited by the number of registers: 32 for NEON/SVE/AVX-512. `kRegCols` == 4
-// enables the `StoreInterleaved4` transpose in `StoreHorizontalSums`. We assume
-// and verify that `C.Cols() % kRegCols == 0`.
-constexpr size_t kRegCols = 4;
-
-// Choosing `kRegRows == kRegCols` minimizes the ratio of loads to FMA, because
-// we load `kRegCols + kRegRows` vectors per `kRegRows * kRegCols` element tile.
-// In general, `batch_size` (A/C rows) is not a multiple of `kRegRows`. Thus
-// functions that load or store a tile are parameterized on `kRowsPerTile`:
-// usually `kRegRows`, but `batch_size % kRegRows` on the last row (if != 0).
-constexpr size_t kRegRows = kRegCols;
-
-struct CacheSizes {
-  CacheSizes() = default;
-  CacheSizes(const BoundedTopology::Cluster& cluster) {
-    // Assumes each package and cluster has the same cache sizes, and uses
-    // reasonable defaults if unknown.
-    l1_bytes = 32 * 1024;  // typical size, rarely changes
-    l2_bytes = (cluster.PrivateKiB() ? cluster.PrivateKiB() : 256) * 1024;
-    l3_bytes = (cluster.SharedKiB() ? cluster.SharedKiB() : 1024) * 1024;
-  }
-
-  size_t l1_bytes;
-  size_t l2_bytes;
-  size_t l3_bytes;
-};
-
+// dimensions `mr (<= kMaxMR)` x `kNR`. To keep FMA units busy, this should be
+// at least the product of the FMA latency (3..5) times the throughput (2).
+// This and `mr` are limited by the number of registers, which is generally
+// 32 but 16 for AVX2. `kNR` == 4 enables the `StoreInterleaved4` transpose in
+// `MMAddHorizontalSumsIntoPartial`. We ensure `C.Cols() % kNR == 0`.
+constexpr size_t kNR = 4;
+
 class MMParallel {
  public:
-  MMParallel() : pools_(nullptr) {}
-  explicit MMParallel(NestedPools& pools) : pools_(&pools) {}
-
-  NestedPools& Pools() const { return *pools_; }
-  hwy::ThreadPool& Pool() const { return pools_->Pool(); }
-
- private:
-  NestedPools* pools_;
-};
-
-// Allocations and threads, shared across MatMul calls.
-class MatMulEnv {
- public:
-  explicit MatMulEnv(NestedPools& pools) : parallel(pools) {
-    const size_t N = hwy::VectorBytes() / sizeof(float);
-    buf_ = RowVectorBatch<float>(Extents2D(pools.MaxWorkers(), 16 * N));
+  static constexpr size_t kMaxPackages = 4;
+
+  MMParallel(NestedPools& pools) : pools_(pools) {
+    HWY_DASSERT(pools_.NumPackages() <= kMaxPackages);
   }
 
-  RowVectorBatch<float>& Buf() { return buf_; }
-
-  MMParallel parallel;
-
-  // TODO: remove once no longer used.
-  NestedPools& Pools() const { return parallel.Pools(); }
-  hwy::ThreadPool& Pool() const { return parallel.Pool(); }
-
+  // Used by tests.
+  NestedPools& Pools() { return pools_; }
+
+  // Initial static partitioning of B rows across packages.
+  IndexRangePartition RangesOfNP(size_t max_packages, size_t N,
+                                 size_t nr) const;
+
+  // For `BindB` and `BindC`.
+  size_t Node(size_t pkg_idx) const {
+    return pools_.Topology().GetCluster(pkg_idx, 0).Node();
+  }
+
+  // Calls `func(pkg_idx)` for each package in parallel.
+  template <class Func>
+  void ForPkg(const size_t max_packages, const Func& func) {
+    pools_.AllPackages().Run(0, HWY_MIN(max_packages, pools_.NumPackages()),
+                             [&](uint64_t task, size_t pkg_idx) {
+                               HWY_DASSERT(task == pkg_idx);
+                               (void)task;
+                               func(pkg_idx);
+                             });
+  }
+
+  // Cluster/CCX-aware parallel-for over B rows in `range_np`. `nx_multiple` is
+  // the granularity of per-cluster tasks. Calls `func(worker_range)`.
+  template <class Func>
+  void ForNP(const IndexRange& range_np, size_t nx_multiple, size_t inner_tasks,
+             size_t pkg_idx, const Func& func) {
+    HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
+    // Single cluster: parallel-for over static partition of `range_np`.
+    hwy::ThreadPool& all_clusters = pools_.AllClusters(pkg_idx);
+    const size_t num_clusters = all_clusters.NumWorkers();
+    if (num_clusters == 1) {
+      hwy::ThreadPool& cluster = pools_.Cluster(pkg_idx, 0);
+      const IndexRangePartition worker_ranges = StaticPartition(
+          range_np, cluster.NumWorkers() * inner_tasks, nx_multiple);
+      return ParallelizeOneRange(
+          worker_ranges, cluster,
+          [&](const IndexRange& worker_range, size_t /*thread*/) {
+            func(worker_range);
+          });
+    }
+
+    // Assign each cluster a sub-range of `range_np` (typically hundreds).
+    const IndexRangePartition nx_ranges =
+        StaticPartition(range_np, num_clusters, nx_multiple);
+    ParallelizeOneRange(
+        nx_ranges, all_clusters,
+        [&](const IndexRange& nx_range, const size_t cluster_idx) {
+          hwy::ThreadPool& cluster = pools_.Cluster(pkg_idx, cluster_idx);
+          // Parallel-for over sub-ranges of `cluster_range` within the cluster.
+          const IndexRangePartition worker_ranges = StaticPartition(
+              nx_range, cluster.NumWorkers() * inner_tasks, nx_multiple);
+          ParallelizeOneRange(worker_ranges, cluster,
+                              [&](const IndexRange& worker_range,
+                                  size_t /*thread*/) { func(worker_range); });
+        });
+  }
+
+  // Cluster/CCX-aware parallel-for over blocks (separate subranges of A and B
+  // rows). Calls `func(range_mc, range_nc)`.
+  template <class Func>
+  void ForRangesMC_NC(const IndexRangePartition& ranges_mc,
+                      const IndexRangePartition& ranges_nc, size_t pkg_idx,
+                      const Func& func) {
+    hwy::ThreadPool& all_clusters = pools_.AllClusters(pkg_idx);
+    // `all_clusters` is a pool with one worker per cluster in a package.
+    const size_t num_clusters = all_clusters.NumWorkers();
+    // Single (big) cluster: collapse two range indices into one parallel-for
+    // to reduce the number of fork-joins.
+    if (num_clusters == 1) {
+      const size_t cluster_idx = 0;
+      hwy::ThreadPool& cluster = pools_.Cluster(pkg_idx, cluster_idx);
+      // Low-batch: avoid Divide/Remainder.
+      if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
+        return ParallelizeOneRange(
+            ranges_nc, cluster,
+            [&](const IndexRange& range_nc, size_t /*thread*/) {
+              func(ranges_mc.Range(0), range_nc);
+            });
+      } else {
+        return ParallelizeTwoRanges(
+            ranges_mc, ranges_nc, cluster,
+            [&](const IndexRange& range_mc, const IndexRange& range_nc,
+                size_t /*thread*/) { func(range_mc, range_nc); });
+      }
+    }
+
+    // Multiple clusters: N across clusters (both are usually the larger), and
+    // M within each cluster. We assume auto-tuning finds small MC/NC tasks.
+    ParallelizeOneRange(
+        ranges_nc, all_clusters,
+        [&](const IndexRange range_nc, size_t cluster_idx) {
+          hwy::ThreadPool& cluster = pools_.Cluster(pkg_idx, cluster_idx);
+          ParallelizeOneRange(
+              ranges_mc, cluster,
+              [&](const IndexRange& range_mc, size_t /*thread*/) {
+                func(range_mc, range_nc);
+              });
+        });
+  }
+
+  // Calls `func(row_a)` in parallel.
+  template <class Func>
+  void ForRangeMC(const IndexRange& range_mc, size_t pkg_idx,
+                  const Func& func) {
+    pools_.Pool(pkg_idx).Run(
+        range_mc.begin(), range_mc.end(),
+        [&](uint64_t row_a, size_t /*thread*/) { func(row_a); });
+  }
+
  private:
-  RowVectorBatch<float> buf_;
+  NestedPools& pools_;
 };
+
+template <typename T>  // float for C, double for partial
+void BindC(size_t M, const RowPtr<T>& C, MMParallel& parallel) {
+  if (!Allocator::ShouldBind()) return;
+
+  const IndexRangePartition ranges_np =
+      parallel.RangesOfNP(MMParallel::kMaxPackages, C.Cols(), kNR);
+  const size_t quantum = Allocator::QuantumBytes() / sizeof(T);
+  bool ok = true;
+  for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
+    const IndexRange& cols_c = ranges_np.Range(pkg_idx);
+    const size_t node = parallel.Node(pkg_idx);
+    for (size_t im = 0; im < M; ++im) {
+      // BindRowsToPackageNodes may not be page-aligned.
+      const size_t begin = hwy::RoundUpTo(cols_c.begin(), quantum);
+      const size_t end = hwy::RoundDownTo(cols_c.end(), quantum);
+      ok &= Allocator::BindMemory(C.Row(im) + begin, (end - begin) * sizeof(T),
+                                  node);
+    }
+  }
+  if (HWY_UNLIKELY(!ok)) {
+    HWY_WARN("Failed to bind C (%zux%zu), %zu packages.", M, C.Cols(),
+             ranges_np.NumTasks());
+  }
+}
+
+// Per-package storage for packed A, and one global C-shaped `partial` for
+// accumulating partial dot products (sections of K).
+class MMStorage {
+ public:
+  // Compile-time bounds on matrix dimensions to enable pre-allocating storage
+  // and reusing it across `MatMul` calls. The resulting allocations are 256 MiB
+  // per package and 512 MiB, respectively.
+  static constexpr size_t kMaxM = 2048;
+  static constexpr size_t kMaxK = 64 * 1024;
+  static constexpr size_t kMaxN = 256 * 1024;
+  // Upper bound for per-worker B storage on the stack. Chosen such that one row
+  // of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`.
+  static constexpr size_t kMaxKC = 8 * 1024;
+
+  explicit MMStorage(MMParallel& parallel) {
+    // Per-package allocation so each can decompress A into its own copy.
+    parallel.ForPkg(MMParallel::kMaxPackages, [&](size_t pkg_idx) {
+      pkg_A_[pkg_idx] = AllocateAlignedRows<BF16>(Extents2D(kMaxM, kMaxK));
+
+      if (Allocator::ShouldBind()) {
+        const size_t node = parallel.Node(pkg_idx);
+        if (!Allocator::BindMemory(pkg_A_[pkg_idx].All(),
+                                   pkg_A_[pkg_idx].NumBytes(), node)) {
+          HWY_WARN("Failed to bind memory for package %zu", pkg_idx);
+        }
+      }
+    });
+
+    // Per-worker copies of `partial` would be wasteful. We instead allocate
+    // one instance of the maximum matrix extents because threads write at
+    // false-sharing-free granularity.
+    partial_storage_ = AllocateAlignedRows<double>(Extents2D(kMaxM, kMaxN));
+    // Same stride independent of the actual C.Cols() so we can pre-bind.
+    partial_ = RowPtrD(partial_storage_.All(), kMaxN,
+                       StrideForCyclicOffsets<double>(kMaxN));
+    // Avoid cross-package accesses.
+    BindC(kMaxM, partial_, parallel);
+  }
+
+  // Returns per-package matrix view. Non-const so that `RowVectorBatch` is
+  // non-const, because `RowPtr` requires a non-const pointer.
+  RowPtrBF A(size_t pkg_idx, const Extents2D& extents) {
+    HWY_DASSERT(extents.rows <= kMaxM);
+    HWY_DASSERT(extents.cols <= kMaxK);
+    const size_t stride = StrideForCyclicOffsets<BF16>(extents.cols);
+    return RowPtrBF(pkg_A_[pkg_idx].All(), extents.cols, stride);
+  }
+
+  RowPtrD Partial() const { return partial_; }
+
+ private:
+  RowVectorBatch<BF16> pkg_A_[MMParallel::kMaxPackages];
+  RowVectorBatch<double> partial_storage_;
+  RowPtrD partial_;
+};
+
+//------------------------------------------------------------------------------
+// Autotuning
+
+// Naming convention: outer loop first, T suffix means threaded. This refers to
+// the loops *around* `A2C0`, which contains loops over mc/kc. The outermost
+// `ranges_np` loop across packages is implicit and applies to all of these.
+//
+// Parallelizing across K (A/B columns) is undesirable because the resulting
+// partial dot products require synchronization or reduction across threads.
+enum class MMOrder : uint8_t {
+  // Single M, parallel N, sequential K (inside the parallel section to
+  // reduce fork-joins). Similar to GotoBLAS, good for large N vs. M and K.
+  kNT_K,
+  // Specialization of `kNT_K` for a single K task with `kDirect`.
+  kNT,
+
+  // Parallelize over blocks of M and N: good when both are large. We no longer
+  // support `kMT_NT_K`: no advantage on Skylake, and `kNT_MT_K` is 1.5x as
+  // fast on Zen4.
+  kNT_MT_K,
+  kNT_MT,  // Specialization of `kNT_MT_K` for a single K task with `kDirect`.
+
+  // Resident C (`kK_M_NT`) should be good for large K relative to M and N.
+  // However, it does not (much) outperform `kNT_K` on SKX and Zen4. There are
+  // no kN* because we expect M (batch size) to be small relative to K and N.
+};
+
+static inline bool IsBlock(MMOrder order) {
+  return order == MMOrder::kNT_MT_K || order == MMOrder::kNT_MT;
+}
+
+static inline bool IsOneKC(MMOrder order) {
+  return order == MMOrder::kNT || order == MMOrder::kNT_MT;
+}
+
+static inline const char* StringFromOrder(MMOrder order) {
+  switch (order) {
+    case MMOrder::kNT_K:
+      return "NT_K";
+    case MMOrder::kNT:
+      return "NT";
+    case MMOrder::kNT_MT_K:
+      return "NT_MT_K";
+    case MMOrder::kNT_MT:
+      return "NT_MT";
+    default:
+      return nullptr;
+  }
+}
+
+// How/where to write the A2C0 result. This determines the `tag` argument to
+// that function, which governs whether we call `MMStoreHorizontalSumsIntoC` or
+// `MMAddHorizontalSumsIntoPartial`.
+enum class MMOut : uint8_t {
+  kCopy,    // accumulate into partial, scale/add to C
+  kDirect,  // single kc task, write directly to C
+  kParM     // kCopy but parallel over M
+  // kParN is not better on SKX/Zen4.
+};
+
+static inline const char* StringFromOut(MMOut out) {
+  switch (out) {
+    case MMOut::kDirect:
+      return "Direct";
+    case MMOut::kCopy:
+      return "Copy";
+    case MMOut::kParM:
+      return "ParM";
+    default:
+      return nullptr;
+  }
+}
+
+// How to parallelize the per-package `DecompressA`. To reduce combinatorial
+// explosion, we tune this separately from `MMConfig`.
+enum class MMParA : uint8_t { kNone, kK1 = 1, kK2 = 2, kK4 = 4, kM };
+
+static inline const char* StringFromParA(MMParA par_a) {
+  switch (par_a) {
+    case MMParA::kNone:
+      return "ParA0 ";
+    case MMParA::kK1:
+      return "ParAK1";
+    case MMParA::kK2:
+      return "ParAK2";
+    case MMParA::kK4:
+      return "ParAK4";
+    case MMParA::kM:
+      return "ParAM ";
+    default:
+      return nullptr;
+  }
+}
+
+// Possible configurations for the autotuner to choose from:
+// `mr` := C rows to write at a time (< #registers / `kNR`),
+// `kc` := A / B columns such that `mr` rows fit in L1,
+// `mc` := A rows such that `kc` columns fit in L2,
+// `nc` := B rows such that `kc` columns fit in L3 alongside `mc x nc` C.
+// Also includes loop order and task granularity.
+#pragma pack(push, 1)
+class MMConfig {
+ public:
+  MMConfig() = default;  // for std::vector
+  // `mr` is the number of A rows per call to `MMKernel::LoopOverKC`.
+  // `MMOrder` is how to parallelize the outer loops.
+  // `MMOut` is how/whether to parallelize filling the C result.
+  // `inner_tasks` chooses the within-cluster task granularity in `ForNP`.
+  MMConfig(size_t K, size_t mr, size_t mc, size_t kc, size_t nc,
+           size_t kc_multiple, size_t nc_multiple, MMOrder order, MMOut out,
+           int inner_tasks)
+      : mr_(static_cast<uint32_t>(mr)),
+        mc_(static_cast<uint32_t>(mc)),
+        kc_(static_cast<uint32_t>(kc)),
+        nc_(static_cast<uint32_t>(nc)),
+        nc_multiple_(static_cast<uint32_t>(nc_multiple)),
+        kc_multiple_(static_cast<uint32_t>(kc_multiple)),
+        order_(order),
+        out_(out),
+        inner_tasks_(static_cast<uint8_t>(inner_tasks)),
+        reserved_{} {
+    HWY_DASSERT(mr == 1 || mr == 2 || mr == 4);
+    if (mc % mr != 0) {
+      HWY_WARN("mc %zu not a multiple of mr %zu", mc, mr);
+    }
+    // Do not warn for single-kc tasks; some models unfortunately have K which
+    // are not multiples of `kc_multiple`.
+    if (kc != K && (kc % kc_multiple) != 0) {
+      HWY_WARN("kc %zu not a multiple of kc_multiple %zu", kc, kc_multiple);
+    }
+    if (nc % nc_multiple != 0) {
+      HWY_WARN("nc %zu not a multiple of nc_multiple %zu", nc, nc_multiple);
||||||
|
if (nc % nc_multiple != 0) {
|
||||||
|
HWY_WARN("nc %zu not a multiple of nc_multiple %zu", nc, nc_multiple);
|
||||||
|
}
|
||||||
|
HWY_DASSERT(StringFromOrder(order_) != nullptr);
|
||||||
|
HWY_DASSERT(StringFromOut(out_) != nullptr);
|
||||||
|
HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Splits M/N into blocks which are visited sequentially or in parallel.
|
||||||
|
// K is always sequential, see `MMOrder`.
|
||||||
|
IndexRangePartition RangesOfMC(size_t M) const {
|
||||||
|
return MaxSizePartition(IndexRange(0, M), mc_, mr_);
|
||||||
|
}
|
||||||
|
IndexRangePartition RangesOfKC(size_t K) const {
|
||||||
|
return MaxSizePartition(IndexRange(0, K), kc_, kc_multiple_);
|
||||||
|
}
|
||||||
|
IndexRangePartition RangesOfNC(IndexRange range_np) const {
|
||||||
|
return MaxSizePartition(range_np, nc_, nc_multiple_);
|
||||||
|
}
|
||||||
|
|
||||||
|
MMOrder Order() const { return order_; }
|
||||||
|
MMOut Out() const { return out_; }
|
||||||
|
// No `OuterTasks` because static partitioning across clusters is sufficient.
|
||||||
|
size_t InnerTasks() const { return static_cast<size_t>(inner_tasks_); }
|
||||||
|
|
||||||
|
// Accessors for printing autotune result.
|
||||||
|
size_t MR() const { return static_cast<size_t>(mr_); }
|
||||||
|
size_t MC() const { return static_cast<size_t>(mc_); }
|
||||||
|
size_t KC() const { return static_cast<size_t>(kc_); }
|
||||||
|
size_t NC() const { return static_cast<size_t>(nc_); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Somewhat-compressed representation because MMCandidates may return dozens.
|
||||||
|
uint32_t mr_;
|
||||||
|
uint32_t mc_;
|
||||||
|
uint32_t kc_;
|
||||||
|
uint32_t nc_;
|
||||||
|
uint32_t nc_multiple_;
|
||||||
|
uint32_t kc_multiple_;
|
||||||
|
MMOrder order_;
|
||||||
|
MMOut out_;
|
||||||
|
uint8_t inner_tasks_;
|
||||||
|
HWY_MAYBE_UNUSED uint8_t reserved_[5];
|
||||||
|
};
|
||||||
|
static_assert(sizeof(MMConfig) == 32); // for faster indexing
|
||||||
|
#pragma pack(pop)
|
||||||
|
|
||||||
|
std::vector<MMConfig> MMCandidates(size_t M, size_t K, size_t N, size_t max_mr,
|
||||||
|
size_t nr,
|
||||||
|
const IndexRangePartition& ranges_np,
|
||||||
|
bool print_config);
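[Editor's note] To make the relationship between these fields concrete, the sketch below constructs one candidate by hand and asks it for cache-blocking ranges. All numeric values are illustrative assumptions; in practice `MMCandidates` derives them from the detected L1/L2/L3 sizes and the autotuner selects among the candidates.

#include <stddef.h>

#include "ops/matmul.h"  // assumed include path for the declarations above

void ExampleConfig() {
  using namespace gcpp;
  const size_t K = 3072;  // A/B columns (illustrative)
  // mr/mc/kc/nc and their multiples are made-up values for this sketch.
  const MMConfig cfg(K, /*mr=*/4, /*mc=*/64, /*kc=*/512, /*nc=*/1024,
                     /*kc_multiple=*/64, /*nc_multiple=*/64,
                     MMOrder::kNT_K, MMOut::kCopy, /*inner_tasks=*/2);
  // Rows of A (here M=8) are split into blocks of at most mc, aligned to mr.
  const IndexRangePartition ranges_mc = cfg.RangesOfMC(/*M=*/8);
  // K is split into blocks of at most kc, each a multiple of kc_multiple.
  const IndexRangePartition ranges_kc = cfg.RangesOfKC(K);
  (void)ranges_mc;
  (void)ranges_kc;
}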
// State machine for choosing the best `TConfig`, which is `MMConfig` for the
// main MatMul autotuner.
template <typename TConfig>
class MMAutoTune {
 public:
  // Returns nullptr if not yet finished, otherwise the best config. Do not
  // store this pointer because it can be invalidated.
  const TConfig* Best() const { return best_; }

  // If false, caller must call `SetCandidates` before `NextConfig`.
  bool HasCandidates() const {
    HWY_DASSERT(!Best());
    return !candidates_.empty();
  }
  void SetCandidates(std::vector<TConfig> candidates) {
    HWY_DASSERT(!HasCandidates());
    candidates_.swap(candidates);
    HWY_DASSERT(HasCandidates());
    min_ticks_.resize(candidates_.size(), ~uint64_t{0});
  }

  // Returns the current `TConfig` to measure.
  const TConfig& NextConfig() const {
    HWY_DASSERT(!Best() && HasCandidates());
    return candidates_[config_idx_];
  }

  // Returns the best ticks so far for this candidate. Negligible CPU time.
  uint64_t NotifyTicks(uint64_t ticks) {
    HWY_DASSERT(HasCandidates());
    HWY_DASSERT(!skipped_.Get(config_idx_));

    best_ticks_ = HWY_MIN(best_ticks_, ticks);
    min_ticks_[config_idx_] = HWY_MIN(min_ticks_[config_idx_], ticks);
    // Best so far. Save because we update `config_idx_` below.
    const size_t my_best_ticks = min_ticks_[config_idx_];
    const size_t my_idx = config_idx_;

    // Advance/wrap around to next non-skipped config. Do this first because it
    // updates `rounds_complete_`. To decorrelate measurements, we do not
    // immediately re-measure the same config.
    for (;;) {
      ++config_idx_;
      if (HWY_UNLIKELY(config_idx_ == candidates_.size())) {
        config_idx_ = 0;
        ++rounds_complete_;
      }
      // Guaranteed to terminate because `best_ticks_` is never worse than any
      // other, hence is not skipped.
      if (!skipped_.Get(config_idx_)) break;
    }

    // Disqualify from future `NextConfig` if the best of two measurements so
    // far is sufficiently worse than `best_ticks_`. This tolerates some noise
    // in the first or second measurement.
    if (rounds_complete_ != 0 && my_best_ticks > 5 * best_ticks_ / 4) {
      skipped_.Set(my_idx);
    }

    // After sufficient rounds, choose the winner.
    if (rounds_complete_ == 4) {
      for (size_t i = 0; i < candidates_.size(); ++i) {
        worst_min_ticks_ = HWY_MAX(worst_min_ticks_, min_ticks_[i]);
        if (min_ticks_[i] == best_ticks_) {
          // Causes `Best()` to be non-null, hence `MatMul` will no longer call
          // `NextConfig` for this shape.
          best_ = &candidates_[i];
          config_idx_ = i;  // just in case callers want to know which index.
        }
      }
      HWY_DASSERT(best_ != nullptr);  // no min_ticks_ matches best_ticks_
    }

    return my_best_ticks;
  }

  // Avoid printing the first two rounds, because those might be noisy and not
  // yet skipped.
  bool ShouldPrint() { return rounds_complete_ > 2; }

  // Only valid after Best() is non-null. Used to compute the autotuning gain.
  uint64_t BestTicks() const { return best_ticks_; }
  uint64_t WorstMinTicks() const { return worst_min_ticks_; }
  uint64_t FirstConfigTicks() const { return min_ticks_[0]; }

 private:
  const TConfig* best_ = nullptr;
  std::vector<TConfig> candidates_;
  // Use Min because threads are pinned, so we only expect additive noise.
  std::vector<uint64_t> min_ticks_;  // one per candidate
  size_t config_idx_ = 0;  // [0, candidates_.size())
  size_t rounds_complete_ = 0;
  uint64_t best_ticks_ = ~uint64_t{0};
  uint64_t worst_min_ticks_ = 0;
  hwy::BitSet4096<> skipped_;
};
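[Editor's note] The calling protocol is easiest to see as a loop per MatMul shape: seed the candidates once, then measure whichever config `NextConfig` returns until `Best` becomes non-null. `MeasureTicksFor` below is a placeholder for however the caller times one MatMul invocation; it is not part of this header.

#include <stdint.h>

#include <utility>
#include <vector>

#include "ops/matmul.h"  // assumed include path

// Placeholder, not part of the API: time one MatMul call using `cfg`.
uint64_t MeasureTicksFor(const gcpp::MMConfig& cfg);

void ExampleAutotune(gcpp::MMAutoTune<gcpp::MMConfig>& tuner,
                     std::vector<gcpp::MMConfig> candidates) {
  if (!tuner.HasCandidates()) tuner.SetCandidates(std::move(candidates));
  while (tuner.Best() == nullptr) {
    const gcpp::MMConfig& cfg = tuner.NextConfig();
    // Feeding back the measurement also advances to the next candidate and
    // eventually disqualifies configs much slower than the best so far.
    tuner.NotifyTicks(MeasureTicksFor(cfg));
  }
  // From now on, Best() returns the winner and no further measuring is needed.
}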
//------------------------------------------------------------------------------

// Map of previously seen dimensions to index via linear search.
class MMKeys {
 public:
  using Key = uint64_t;
  // KeyFromDims will only return this if all dims are zero, which is invalid.
  static constexpr Key kPadding = 0;

  // Compresses the dimensions into a single Key for faster comparison.
  static Key KeyFromDims(size_t M, size_t K, size_t N) {
    HWY_DASSERT(M < (Key{1} << 16));  // batch sizes are smaller
    HWY_DASSERT(K < (Key{1} << 24));
    HWY_DASSERT(N < (Key{1} << 24));
    const Key key = static_cast<Key>(M) | (static_cast<Key>(K) << 16) |
                    (static_cast<Key>(N) << 40);
    HWY_DASSERT(key != kPadding);
    return key;
  }

  // We leave the search to callers so they can use dynamic-dispatched SIMD,
  // which is not possible in this header.
  hwy::Span<const Key> Keys() const {
    return hwy::Span<const Key>(keys_.get(), num_unique_);
  }

  // Must only be called if not already present in `Keys()`.
  void Append(Key key) {
    // Dynamic allocation because the test checks many more dimensions than
    // would be reasonable to pre-allocate. DIY for alignment and padding.
    if (HWY_UNLIKELY(num_unique_ >= capacity_)) {
      const size_t NU64 = Allocator::VectorBytes() / sizeof(Key);
      // Start at one vector so the size is always a multiple of N.
      if (HWY_UNLIKELY(capacity_ == 0)) {
        capacity_ = hwy::DivCeil(NU64, 2);  // will be doubled below
      }
      capacity_ *= 2;
      HWY_DASSERT(capacity_ >= num_unique_ + 1);
      hwy::AlignedFreeUniquePtr<Key[]> new_keys =
          hwy::AllocateAligned<Key>(capacity_);
      hwy::CopyBytes(keys_.get(), new_keys.get(), num_unique_ * sizeof(Key));
      // Pad for SIMD.
      for (size_t i = num_unique_; i < hwy::RoundUpTo(num_unique_, NU64); ++i) {
        new_keys[i] = kPadding;
      }
      keys_.swap(new_keys);
    }
    keys_[num_unique_++] = key;
  }

 private:
  size_t capacity_ = 0;
  size_t num_unique_ = 0;
  hwy::AlignedFreeUniquePtr<Key[]> keys_;
};
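[Editor's note] For orientation: a key packs M into bits 0..15, K into bits 16..39 and N into bits 40..63, so one 64-bit compare suffices per stored shape. A scalar lookup-or-append would look roughly like the sketch below; the function name is an assumption, and real callers use dynamically dispatched SIMD for the search as noted above.

#include <stddef.h>

#include "ops/matmul.h"  // assumed include path

// Returns the index of this shape, appending it on first sight.
size_t ExampleIndexOfDims(gcpp::MMKeys& keys, size_t M, size_t K, size_t N) {
  const gcpp::MMKeys::Key key = gcpp::MMKeys::KeyFromDims(M, K, N);
  const hwy::Span<const gcpp::MMKeys::Key> existing = keys.Keys();
  for (size_t i = 0; i < existing.size(); ++i) {
    if (existing[i] == key) return i;  // previously seen dimensions
  }
  keys.Append(key);  // contract: only call for keys not already present
  return existing.size();
}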
// Per-MatMul-shape state.
struct MMPerKey {
  MMPerKey(size_t max_packages, size_t N, size_t nr, MMParallel& parallel)
      : ranges_np(parallel.RangesOfNP(max_packages, N, nr)) {}

  // Only profile if enabled and the main autotuner finished (the par_a
  // autotuner is per-package and we want to avoid synchronization).
  bool WantProfile() const { return PROFILER_ENABLED != 0 && autotune.Best(); }

  const IndexRangePartition ranges_np;
  MMAutoTune<MMConfig> autotune;
  MMAutoTune<MMParA> autotune_par_a[MMParallel::kMaxPackages];
};

// Stores state shared across MatMul calls. Non-copyable.
struct MatMulEnv {
  explicit MatMulEnv(NestedPools& pools);

  bool have_timer_stop = false;

  // Enable binding: disabled in Gemma until tensors support it, enabled in
  // bench_matmul.cc.
  bool enable_bind = false;

  // Whether `MMCandidates()` should print the set of parameters.
  bool print_config = false;
  // Whether to print each config's speed during autotuning.
  bool print_measurement = false;
  // Whether to print the best config immediately after autotuning finished.
  bool print_best = false;

  MMParallel parallel;
  MMStorage storage;
  MMKeys keys;
  std::vector<MMPerKey> per_key;
};
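[Editor's note] A typical setup, mirroring what the updated tests further down in this change do, is sketched here; it assumes `NestedPools` from util/threading.h and the `Allocator` from util/allocator.h.

#include "ops/matmul.h"      // assumed include path
#include "util/allocator.h"
#include "util/threading.h"

void ExampleEnvSetup() {
  gcpp::NestedPools pools(0);               // 0 = no limit on cores
  gcpp::Allocator::Init(pools.Topology());  // must precede any allocation
  gcpp::MatMulEnv env(pools);
  env.print_best = true;  // report the winning MMConfig per shape
  // Pass `env` to every MatMul call; per-shape autotuning state accumulates
  // in env.per_key and is reused across calls.
}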
// Arguments to MatMul() that are independent of the A/B type.
// Reduces register pressure compared to individual values/references.
struct MMArgs {
  MMArgs(MatMulEnv& env, MMPerKey& per_key, double scale,
         const float* HWY_RESTRICT add, const RowPtrD& partial,
         const RowPtrF& C)
      : env(&env),
        per_key(&per_key),
        scale(scale),
        add(add),
        partial(partial),
        C(C) {}

  MatMulEnv* env;
  MMPerKey* per_key;

  double scale;
  const float* HWY_RESTRICT add;
  // Same size as C, threads write at false-sharing-free granularity.
  RowPtrD partial;
  RowPtrF C;
};

// Wrapper over hwy::Zone that is only enabled when autotuning finished.
#if PROFILER_ENABLED
class MMZone {
  using Zone = hwy::Zone;
  static_assert(alignof(Zone) <= 8 && sizeof(Zone) <= 8);

 public:
  ~MMZone() {
    if (used_) {
      Zone* zone = reinterpret_cast<Zone*>(&data_);
      zone->~Zone();
    }
  }

  // `name` must be a string literal.
  void MaybeEnter(const char* name, const MMArgs& args) {
    if (args.per_key->WantProfile()) {
      new (&data_) Zone(name);
      used_ = true;
    }
  }

 private:
  uint64_t data_ = 0;
  bool used_ = false;
};
#else
struct MMZone {
  void MaybeEnter(const char*, const MMArgs&) {}
};
#endif  // PROFILER_ENABLED
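[Editor's note] Usage is intentionally minimal: declare an `MMZone` in the scope to be profiled and call `MaybeEnter`, which only starts a zone once the autotuner has settled so that profiling does not distort the measurements it depends on. The zone name below is illustrative.

#include "ops/matmul.h"  // assumed include path

void ExampleStage(const gcpp::MMArgs& args) {
  gcpp::MMZone zone;
  zone.MaybeEnter("MM.ExampleStage", args);  // no-op before autotuning is done
  // ... the work to be attributed to this zone ...
}  // if entered, the zone ends when `zone` is destroyed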
// Used for the A and B arguments of `MatMul`, which are always const.
// Create via MakeConstMat. This differs from `RowPtr` in that it supports the
// `ofs` required for compressed T.

@@ -161,6 +742,29 @@ ConstMat<T> ConstMatFromWeights(const MatPtrT<T>& m, size_t ofs = 0) {
  return mat;
}

template <typename TB>
void BindB(size_t N, const ConstMat<TB>& B, MMParallel& parallel) {
  if (!Allocator::ShouldBind()) return;

  const IndexRangePartition ranges_np =
      parallel.RangesOfNP(MMParallel::kMaxPackages, N, kNR);
  const size_t quantum = Allocator::QuantumBytes() / sizeof(TB);
  for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
    const IndexRange& rows_b = ranges_np.Range(pkg_idx);
    const size_t node = parallel.Node(pkg_idx);
    uintptr_t begin =
        reinterpret_cast<uintptr_t>(B.ptr + B.Row(rows_b.begin()));
    uintptr_t end = begin + rows_b.Num() * B.Stride() * sizeof(TB);
    // B is not yet guaranteed to have padded rows, so only bind the
    // subset that is page-aligned.
    begin = hwy::RoundUpTo(begin, quantum);
    end = hwy::RoundDownTo(end, quantum);
    if (HWY_LIKELY(begin != end)) {
      Allocator::BindMemory(reinterpret_cast<void*>(begin), end - begin, node);
    }
  }
}

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_H_
@@ -243,17 +243,20 @@ template <typename TA, typename TB = TA>
void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
                MatMulEnv& env) {
  hwy::ThreadPool& pool = env.parallel.Pools().Pool();
  fprintf(stderr, "TestMatMul %zu, %zu, %zu, add=%d, TA=%s, TB=%s\n", rows_ac,
  fprintf(stderr, "TestMatMul %zu, K=%zu, %zu, add=%d, TA=%s, TB=%s\n", rows_ac,
          cols_a_rows_b, cols_bc, add, TypeName<TA>(), TypeName<TB>());

  env.print_config = true;
  env.print_best = true;

  const Extents2D A_extents(rows_ac, cols_a_rows_b);
  const Extents2D B_extents(cols_bc, cols_a_rows_b);  // already transposed
  const Extents2D C_extents(rows_ac, cols_bc);

  MatStoragePtr<TA> a = GenerateMat<TA>(A_extents, pool);
  MatStoragePtr<TB> b_trans = GenerateTransposedMat<TB>(B_extents, pool);
  RowVectorBatch<float> c_slow_batch(C_extents);
  RowVectorBatch<float> c_slow_batch = AllocateAlignedRows<float>(C_extents);
  RowVectorBatch<float> c_batch(C_extents);
  RowVectorBatch<float> c_batch = AllocateAlignedRows<float>(C_extents);
  HWY_ASSERT(a && b_trans);

  std::unique_ptr<MatStorageT<float>> add_storage;

@@ -270,8 +273,12 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
  const RowPtrF C = RowPtrFromBatch(c_batch);

  MatMulSlow(A, B, add_row, env, C_slow);
  MatMul(A, B, add_row, env, C);
  AssertClose(A, B, C_slow, C);
  // A few reps to get coverage of the various autotuned code paths.
  for (size_t rep = 0; rep < 16; ++rep) {
    MMPerKey* per_key = MatMul(A, B, add_row, env, C);
    AssertClose(A, B, C_slow, C);
    if (per_key->autotune.Best()) break;
  }
}

using F32 = float;

@@ -298,13 +305,12 @@ void TestTiny() {
  Tristate use_spinning = Tristate::kDefault;
  pools.MaybeStartSpinning(use_spinning);
  Allocator::Init(pools.Topology());
  Allocator::Init(pools.Topology(), /*enable_bind=*/true);
  MatMulEnv env(pools);

  for (size_t M = 1; M <= 3 * kRegRows; ++M) {
    for (size_t K = 64; K <= 128; K *= 2) {
      for (size_t N = /*kRegRows*/ 16; N <= 64;
           N += max_packages * kRegRows) {
  for (size_t M = 1; M <= 12; ++M) {
    for (size_t K = 1; K <= 64; K *= 2) {
      for (size_t N = 4; N <= 64; N += max_packages * 4) {
        TestMatMul<F32, F32>(M, K, N, /*add=*/false, env);
      }
    }
  }

@@ -323,7 +329,7 @@ void TestAllMatMul() {
  NestedPools pools(0);  // no limits
  Tristate use_spinning = Tristate::kDefault;
  pools.MaybeStartSpinning(use_spinning);
  Allocator::Init(pools.Topology());
  Allocator::Init(pools.Topology(), /*enable_bind=*/true);
  MatMulEnv env(pools);

  // Sizes seen in gemma_test 2B.
@@ -1,17 +0,0 @@
// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// TODO: Tests of individual MatMul components.
int main() { return 0; }
@@ -35,6 +35,7 @@
#include "gemma/common.h"
#include "gemma/configs.h"
#include "util/allocator.h"
#include "util/app.h"
#include "util/test_util.h"
#include "hwy/base.h"
#include "hwy/tests/hwy_gtest.h"

@@ -386,6 +387,9 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
}

void TestRopeAndMulBy() {
  NestedPools pools = CreatePools(AppArgs());
  Allocator::Init(pools.Topology());

  ModelConfig config = ConfigFromModel(Model::GEMMA2_9B);
  int dim_qkv = config.layer_configs[0].qkv_dim;
  RowVectorBatch<float> x(Extents2D(1, dim_qkv));
@@ -34,6 +34,11 @@
#endif

#ifndef GEMMA_BIND  // allow override
// OSes will generally do the right thing when threads allocate their own
// working memory. However, matmul's B and C matrices are preferably sharded
// across NUMA nodes. To simplify the matrix representation, we prefer a
// single allocation. This requires page-level control over the memory layout,
// which Linux provides via `move_pages`, but Windows does not.
#if defined(GEMMA_LINUX_SYSCALL6) && !defined(__ANDROID_API__)
#define GEMMA_BIND 1
#else

@@ -93,7 +98,7 @@ size_t Allocator::L2Bytes() { return l2_bytes_; }
size_t Allocator::L3Bytes() { return l3_bytes_; }
bool Allocator::ShouldBind() { return should_bind_; }

void Allocator::Init(const BoundedTopology& topology) {
void Allocator::Init(const BoundedTopology& topology, bool enable_bind) {
  line_bytes_ = DetectLineBytes();
  vector_bytes_ = hwy::VectorBytes();
  step_bytes_ = HWY_MAX(line_bytes_, vector_bytes_);

@@ -122,13 +127,19 @@ void Allocator::Init(const BoundedTopology& topology) {
  const size_t page_bytes = DetectPageSize();
  if ((page_bytes != 0 && page_bytes <= 16 * 1024) &&
      topology.NumNodes() > 1 && topology.NumPackages() > 1) {
    // Ensure pages meet the alignment requirements of `AllocBytes`.
    HWY_ASSERT(page_bytes >= quantum_bytes_);
    quantum_bytes_ = page_bytes;
    // Ensure MaxQuantumBytes() is an upper bound.
    HWY_ASSERT(MaxQuantumBytes() >= quantum_bytes_);
    quantum_bytes_ = HWY_MIN(quantum_bytes_, MaxQuantumBytes());
    should_bind_ = true;
    if (enable_bind) {
      // Ensure pages meet the alignment requirements of `AllocBytes`.
      HWY_ASSERT(page_bytes >= quantum_bytes_);
      quantum_bytes_ = page_bytes;
      // Ensure MaxQuantumBytes() is an upper bound.
      HWY_ASSERT(MaxQuantumBytes() >= quantum_bytes_);
      quantum_bytes_ = HWY_MIN(quantum_bytes_, MaxQuantumBytes());
      should_bind_ = true;
    } else {
      HWY_WARN(
          "Multiple sockets but binding disabled. This reduces speed; "
          "set or remove enable_bind to avoid this warning.");
    }
  }
}
@@ -16,6 +16,8 @@
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_ALLOCATOR_H_
#define THIRD_PARTY_GEMMA_CPP_UTIL_ALLOCATOR_H_

// Allocator with support for sharding tensors across NUMA nodes.

#include <stddef.h>
#include <stdint.h>

@@ -65,7 +67,8 @@ class Allocator {
 public:
  // Must be called at least once before any other function. Not thread-safe,
  // hence only call this from the main thread.
  static void Init(const BoundedTopology& topology);
  // TODO: remove enable_bind once Gemma tensors support binding.
  static void Init(const BoundedTopology& topology, bool enable_bind = false);

  // Bytes per cache line, or a reasonable guess if unknown. Used to choose
  // ranges such that there will be no false sharing.

@@ -80,8 +83,10 @@ class Allocator {
  static constexpr size_t MaxQuantumBytes() { return 4096; }
  static size_t QuantumSteps();  // = QuantumBytes() / StepBytes()

  // L1 and L2 are typically per core.
  static size_t L1Bytes();
  static size_t L2Bytes();
  // Clusters often share an L3. We return the total size per package.
  static size_t L3Bytes();

  // Returns pointer aligned to `QuantumBytes()`.

@@ -119,6 +124,19 @@ class Allocator {
  static PtrAndDeleter AllocBytes(size_t bytes);
};

// Value of `stride` to pass to `RowVectorBatch` to enable the "cyclic offsets"
// optimization. If `Allocator::ShouldBind()`, `Allocator::QuantumBytes()` is
// typically 4KiB. To avoid remote accesses, we would thus pad each row to that,
// which results in 4K aliasing and/or cache conflict misses. `RowPtr` is able
// to prevent that by pulling rows forward by a cyclic offset, which is still a
// multiple of the cache line size. This requires an additional
// `Allocator::QuantumBytes()` of padding after also rounding up to that.
template <typename T>
constexpr size_t StrideForCyclicOffsets(size_t cols) {
  const size_t quantum = Allocator::MaxQuantumBytes() / sizeof(T);
  return hwy::RoundUpTo(cols, quantum) + quantum;
}
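[Editor's note] A worked example of this formula, using the `MaxQuantumBytes() == 4096` shown above: for BF16 (2 bytes per element) the quantum is 2048 elements, so a row of 3072 columns gets a stride of RoundUpTo(3072, 2048) + 2048 = 4096 + 2048 = 6144 elements, i.e. 12 KiB. The extra quantum is the slack that lets `RowPtr` pull rows forward by a cyclic offset. The column count 3072 is only an illustration, and the compile-time checks below assume the function is usable in constant expressions as its `constexpr` declaration suggests.

#include "hwy/base.h"        // hwy::bfloat16_t
#include "util/allocator.h"  // assumed include path

// BF16 (2 bytes): quantum = 4096 / 2 = 2048 elements.
static_assert(gcpp::StrideForCyclicOffsets<hwy::bfloat16_t>(3072) == 6144,
              "round up to 4096, then add one 2048-element quantum");
// double (8 bytes): quantum = 512 elements, so 3072 cols -> 3072 + 512 = 3584.
static_assert(gcpp::StrideForCyclicOffsets<double>(3072) == 3584, "");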
// Owns dynamically-allocated aligned memory for a batch of row vectors.
// This can be seen as a (batch_size x cols) matrix. Unlike `RowPtr`, this owns
// the memory.

@@ -130,6 +148,7 @@ class RowVectorBatch {
  // Main ctor, called from Activations::Allocate. If `stride` = 0, the default,
  // we default to tightly packed rows (`stride = cols`).
  // WARNING: not all call sites support `stride` != cols.
  // TODO: once they do, remove stride and behave like AllocateAlignedRows here.
  RowVectorBatch(Extents2D extents, size_t stride = 0) : extents_(extents) {
    if (stride == 0) {
      stride_ = extents_.cols;

@@ -137,7 +156,10 @@ class RowVectorBatch {
      HWY_ASSERT(stride >= extents_.cols);
      stride_ = stride;
    }
    mem_ = Allocator::Alloc<T>(extents_.rows * stride_);
    // Allow binding the entire matrix.
    const size_t padded = hwy::RoundUpTo(extents_.rows * stride_,
                                         Allocator::QuantumBytes() / sizeof(T));
    mem_ = Allocator::Alloc<T>(padded);
  }

  // Move-only

@@ -186,6 +208,11 @@ static HWY_INLINE size_t RoundUpToOddLines(size_t num, size_t line_bytes) {
  return padded_num;
}

template <typename T>
RowVectorBatch<T> AllocateAlignedRows(Extents2D extents) {
  return RowVectorBatch<T>(extents, StrideForCyclicOffsets<T>(extents.cols));
}
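[Editor's note] The tests in this change switch their C matrices over to this helper; a minimal use looks like the following. The extents are illustrative.

#include "util/allocator.h"  // assumed include path

void ExampleAlignedBatch() {
  const gcpp::Extents2D extents(/*rows=*/4, /*cols=*/3072);
  // For float (4 bytes), quantum = 4096 / 4 = 1024 elements, so each row is
  // padded to StrideForCyclicOffsets<float>(3072) = 3072 + 1024 = 4096 elements.
  gcpp::RowVectorBatch<float> c = gcpp::AllocateAlignedRows<float>(extents);
  (void)c;
}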
// Lightweight version of `MatPtr` used for the C argument of `MatMul`, because
// it is always float and does not support compressed T, but does support an
// arbitrary stride >= cols.

@@ -202,7 +229,19 @@ class RowPtr {
        row_mask_(Allocator::QuantumSteps() - 1) {
    HWY_DASSERT(stride >= cols);
    HWY_DASSERT(row_mask_ != ~size_t{0});
    row_mask_ = 0;  // TODO: remove
    if constexpr (HWY_IS_DEBUG_BUILD) {
      if (stride < StrideForCyclicOffsets<T>(cols)) {
        static bool once;
        if (!once) {
          once = true;
          HWY_WARN(
              "Check why RowPtr stride=%zu < StrideForCyclicOffsets(cols=%zu), "
              "T=%zu; this forces us to disable cyclic offsets.",
              stride, cols, sizeof(T));
        }
        row_mask_ = 0;
      }
    }
  }
  RowPtr(T* HWY_RESTRICT row0, size_t cols) : RowPtr(row0, cols, cols) {}
13
util/app.h

@@ -295,6 +295,19 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
    runtime_config.max_generated_tokens = max_generated_tokens;
    runtime_config.prefill_tbatch_size = prefill_tbatch_size;
    runtime_config.decode_qbatch_size = decode_qbatch_size;
    if (prefill_tbatch_size > MMStorage::kMaxM) {
      HWY_ABORT(
          "prefill_tbatch_size %zu > kMaxM %zu: specify a smaller value, "
          "or increase the constant in MMStorage.\n",
          prefill_tbatch_size, MMStorage::kMaxM);
    }
    if (decode_qbatch_size > MMStorage::kMaxM) {
      HWY_ABORT(
          "decode_qbatch_size %zu > kMaxM %zu: specify a smaller value, "
          "or increase the constant in MMStorage.\n",
          decode_qbatch_size, MMStorage::kMaxM);
    }

    runtime_config.temperature = temperature;
    runtime_config.top_k = top_k;
  }