Matmul rewrite: fp64 sums, hierarchical parallelization, cache-blocking, autotuning

Remove empty matmul_unit_test.
Up to 25 TFLOP/s on 2x Zen4 for M=512, K=3072, N=24576.

PiperOrigin-RevId: 729123576
Jan Wassenberg 2025-02-20 08:32:52 -08:00 committed by Copybara-Service
parent d854471ae2
commit f9d93e4a42
16 changed files with 2527 additions and 591 deletions


@ -85,6 +85,9 @@ test_suite(
cc_library( cc_library(
name = "ops", name = "ops",
srcs = [
"ops/matmul.cc",
],
hdrs = [ hdrs = [
"ops/matmul.h", "ops/matmul.h",
"ops/ops.h", "ops/ops.h",
@ -103,11 +106,14 @@ cc_library(
":threading", ":threading",
"//compression:compress", "//compression:compress",
"@highway//:algo", "@highway//:algo",
"@highway//:bit_set",
"@highway//:hwy", "@highway//:hwy",
"@highway//:math", "@highway//:math",
"@highway//:matvec", "@highway//:matvec",
"@highway//:nanobenchmark",
"@highway//:profiler", "@highway//:profiler",
"@highway//:thread_pool", "@highway//:thread_pool",
"@highway//:topology",
"@highway//hwy/contrib/sort:vqsort", "@highway//hwy/contrib/sort:vqsort",
], ],
) )
@ -126,6 +132,7 @@ cc_test(
":test_util", ":test_util",
":threading", ":threading",
"@googletest//:gtest_main", # buildcleaner: keep "@googletest//:gtest_main", # buildcleaner: keep
"//:app",
"//compression:compress", "//compression:compress",
"//compression:test_util", "//compression:test_util",
"@highway//:hwy", "@highway//:hwy",
@ -151,6 +158,7 @@ cc_test(
":ops", ":ops",
":test_util", ":test_util",
"@googletest//:gtest_main", # buildcleaner: keep "@googletest//:gtest_main", # buildcleaner: keep
"//:app",
"//compression:compress", "//compression:compress",
"@highway//:hwy", "@highway//:hwy",
"@highway//:hwy_test_util", "@highway//:hwy_test_util",
@ -176,26 +184,6 @@ cc_test(
], ],
) )
cc_test(
name = "matmul_unit_test",
size = "small",
timeout = "long",
srcs = ["ops/matmul_unit_test.cc"],
local_defines = ["HWY_IS_TEST"],
# for test_suite.
tags = ["ops_tests"],
deps = [
":allocator",
":basics",
":ops",
":test_util",
"@googletest//:gtest_main", # buildcleaner: keep
"//compression:compress",
"@highway//:hwy",
"@highway//:hwy_test_util",
],
)
cc_test( cc_test(
name = "matmul_test", name = "matmul_test",
size = "small", size = "small",
@ -652,6 +640,7 @@ cc_test(
":sampler", ":sampler",
":weights", ":weights",
"@googletest//:gtest_main", "@googletest//:gtest_main",
"//:threading",
"//compression:compress", "//compression:compress",
"@highway//:hwy", "@highway//:hwy",
"@highway//:hwy_test_util", "@highway//:hwy_test_util",


@ -92,6 +92,8 @@ set(SOURCES
gemma/weights.h gemma/weights.h
ops/dot-inl.h ops/dot-inl.h
ops/matmul-inl.h ops/matmul-inl.h
ops/matmul.cc
ops/matmul.h
ops/matvec-inl.h ops/matvec-inl.h
ops/ops-inl.h ops/ops-inl.h
ops/ops.h ops/ops.h
@ -168,7 +170,6 @@ set(GEMMA_TEST_FILES
ops/dot_test.cc ops/dot_test.cc
ops/gemma_matvec_test.cc ops/gemma_matvec_test.cc
ops/matmul_test.cc ops/matmul_test.cc
ops/matmul_unit_test.cc
ops/ops_test.cc ops/ops_test.cc
paligemma/image_test.cc paligemma/image_test.cc
paligemma/paligemma_test.cc paligemma/paligemma_test.cc


@ -33,6 +33,7 @@
#include "backprop/test_util.h" #include "backprop/test_util.h"
#include "gemma/configs.h" #include "gemma/configs.h"
#include "ops/ops.h" #include "ops/ops.h"
#include "util/threading.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
@ -58,7 +59,9 @@ void TestMatMulVJP() {
static const size_t kRows = 8; static const size_t kRows = 8;
static const size_t kCols = 64; static const size_t kCols = 64;
static const size_t kTokens = 5; static const size_t kTokens = 5;
hwy::ThreadPool pool(8); gcpp::NestedPools pools(1, /*pin=*/Tristate::kFalse, BoundedSlice(0, 1),
BoundedSlice(0, 8));
Allocator::Init(pools.Topology());
std::mt19937 gen(42); std::mt19937 gen(42);
MatStorageT<float> weights("weights", kRows, kCols); MatStorageT<float> weights("weights", kRows, kCols);
MatStorageT<float> x("x", kTokens, kCols); MatStorageT<float> x("x", kTokens, kCols);
@ -85,7 +88,7 @@ void TestMatMulVJP() {
grad.ZeroInit(); grad.ZeroInit();
MatMulVJP(weights.data(), x.data(), dy.data(), kCols, kRows, kTokens, MatMulVJP(weights.data(), x.data(), dy.data(), kCols, kRows, kTokens,
grad.data(), dx.data(), pool); grad.data(), dx.data(), pools.Pool());
TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__); TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__);
TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__); TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__);
@ -102,7 +105,9 @@ void TestMultiHeadMatMulVJP() {
static const size_t kCols = 16; static const size_t kCols = 16;
static const size_t kHeads = 4; static const size_t kHeads = 4;
static const size_t kTokens = 3; static const size_t kTokens = 3;
hwy::ThreadPool pool(8); gcpp::NestedPools pools(1, /*pin=*/Tristate::kFalse, BoundedSlice(0, 1),
BoundedSlice(0, 8));
Allocator::Init(pools.Topology());
std::mt19937 gen(42); std::mt19937 gen(42);
MatStorageT<float> weights("weights", kRows, kCols * kHeads); MatStorageT<float> weights("weights", kRows, kCols * kHeads);
MatStorageT<float> x("x", kTokens, kCols * kHeads); MatStorageT<float> x("x", kTokens, kCols * kHeads);
@ -130,7 +135,7 @@ void TestMultiHeadMatMulVJP() {
grad.ZeroInit(); grad.ZeroInit();
MultiHeadMatMulVJP(weights.data(), x.data(), dy.data(), kHeads, kCols, MultiHeadMatMulVJP(weights.data(), x.data(), dy.data(), kHeads, kCols,
kRows, kTokens, grad.data(), dx.data(), pool); kRows, kTokens, grad.data(), dx.data(), pools.Pool());
TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__); TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__);
TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__); TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__);
@ -145,7 +150,9 @@ void TestMultiHeadMatMulVJP() {
void TestRMSNormVJP() { void TestRMSNormVJP() {
static const size_t K = 2; static const size_t K = 2;
static const size_t N = 64; static const size_t N = 64;
hwy::ThreadPool pool(8); gcpp::NestedPools pools(1, /*pin=*/Tristate::kFalse, BoundedSlice(0, 1),
BoundedSlice(0, 8));
Allocator::Init(pools.Topology());
std::mt19937 gen(42); std::mt19937 gen(42);
MatStorageT<float> weights("weights", N, 1); MatStorageT<float> weights("weights", N, 1);
MatStorageT<float> x("x", K, N); MatStorageT<float> x("x", K, N);
@ -172,7 +179,7 @@ void TestRMSNormVJP() {
grad.ZeroInit(); grad.ZeroInit();
RMSNormVJP(weights.data(), x.data(), dy.data(), N, K, grad.data(), RMSNormVJP(weights.data(), x.data(), dy.data(), N, K, grad.data(),
dx.data(), pool); dx.data(), pools.Pool());
TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__); TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__);
TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__); TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__);
@ -209,7 +216,9 @@ static ModelConfig TestConfig() {
void TestEndToEnd() { void TestEndToEnd() {
std::mt19937 gen(42); std::mt19937 gen(42);
hwy::ThreadPool pool(0); gcpp::NestedPools pools(1, /*pin=*/Tristate::kFalse, BoundedSlice(0, 1),
BoundedSlice(0, 1));
Allocator::Init(pools.Topology());
ModelConfig config = TestConfig(); ModelConfig config = TestConfig();
WeightsWrapper<float> weights(config); WeightsWrapper<float> weights(config);
WeightsWrapper<float> grad(config); WeightsWrapper<float> grad(config);
@ -234,13 +243,13 @@ void TestEndToEnd() {
float loss1 = CrossEntropyLossForwardPass( float loss1 = CrossEntropyLossForwardPass(
prompt.tokens, prompt.context_size, weights.get(), forward1, prompt.tokens, prompt.context_size, weights.get(), forward1,
inv_timescale, pool); inv_timescale, pools.Pool());
EXPECT_NEAR(loss1, loss0, std::abs(loss0) * 2e-5); EXPECT_NEAR(loss1, loss0, std::abs(loss0) * 2e-5);
grad.ZeroInit(); grad.ZeroInit();
CrossEntropyLossBackwardPassInl(prompt, weights.get(), forward1, grad.get(), CrossEntropyLossBackwardPassInl(prompt, weights.get(), forward1, grad.get(),
backward, inv_timescale, pool); backward, inv_timescale, pools.Pool());
Complexify(weights.get(), c_weights.get()); Complexify(weights.get(), c_weights.get());
auto func = [&]() { auto func = [&]() {


@ -252,6 +252,10 @@ class Gemma {
void GenerateImageTokens(const RuntimeConfig& runtime_config, void GenerateImageTokens(const RuntimeConfig& runtime_config,
const Image& image, ImageTokens& image_tokens); const Image& image, ImageTokens& image_tokens);
void SetMatMulVerbosity(int verbosity) {
if (verbosity >= 2) env_.print_best = true;
}
private: private:
MatMulEnv env_; MatMulEnv env_;


@ -225,7 +225,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
} }
if (end_of_turn_seen && abs_pos > 0) { if (end_of_turn_seen && abs_pos > 0) {
// If we have seen an end_of_turn token, we need to rewind abs_pos by one // If we have seen an end_of_turn token, we need to rewind abs_pos by one
// more, because we will pre-pend it again to the prompt in // more, because we will prepend it again to the prompt in
// WrapAndTokenize. // WrapAndTokenize.
abs_pos--; abs_pos--;
} }
@ -236,14 +236,13 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
PROFILER_ZONE("Run.misc"); PROFILER_ZONE("Run.misc");
// TODO: remove once MatMul is updated.
app.max_packages = 1;
// Note that num_threads is an upper bound; we also limit to the number of // Note that num_threads is an upper bound; we also limit to the number of
// detected and enabled cores. // detected and enabled cores.
NestedPools pools = CreatePools(app); NestedPools pools = CreatePools(app);
Allocator::Init(pools.Topology()); Allocator::Init(pools.Topology());
Gemma model = CreateGemma(loader, pools); Gemma model = CreateGemma(loader, pools);
model.SetMatMulVerbosity(app.verbosity);
KVCache kv_cache = KVCache kv_cache =
KVCache::Create(model.GetModelConfig(), inference.prefill_tbatch_size); KVCache::Create(model.GetModelConfig(), inference.prefill_tbatch_size);


@ -117,17 +117,18 @@ MatStoragePtr<MatT> GenerateTransposedMat(const Extents2D extents,
} }
void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents, void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents,
std::vector<double>& times) { std::vector<double>& times, MMPerKey* per_key) {
std::sort(times.begin(), times.end()); std::sort(times.begin(), times.end());
// bench_dnn reports the best and average, but the median seems more // bench_dnn reports the best and average, but the median seems more
// consistent and resistant to outliers. // consistent and resistant to outliers.
const double elapsed = times[times.size() / 2]; const double elapsed = times[times.size() / 2];
const double ratio = elapsed / (times[0] + 1E-6); // vs best, avoid / 0 const double vs_best = elapsed / (times[0] + 1E-6); // avoid / 0
const size_t num_b = B_extents.Area(); const size_t num_b = B_extents.Area();
// FMA counts as two FLOP. const double flops = 2 * A_extents.rows * num_b / elapsed; // FMA = 2 ops
fprintf(stderr, "%.1f\t(med %.3f ms = %0.2fx min)\n",
2 * 1E-9 * A_extents.rows * num_b / elapsed, elapsed * 1E3, ratio); fprintf(stderr, "\t%.1f GFLOPS %.3f ms %0.2fx\n", flops * 1E-9, elapsed * 1E3,
vs_best);
} }
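// Illustrative sanity check (not from the original sources): the formula
// above gives FLOP = 2*M*K*N. For the M=512, K=3072, N=24576 shape cited in
// the commit message, that is ~7.7e10 FLOP, so a median time of ~3.1 ms
// corresponds to the reported ~25 TFLOP/s.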
// Generates inputs and prints observed throughput of MatMul. // Generates inputs and prints observed throughput of MatMul.
@ -135,15 +136,18 @@ void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents,
template <typename MatTA, typename MatTB = MatTA> template <typename MatTA, typename MatTB = MatTA>
void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
hwy::ThreadPool& pool = env.parallel.Pools().Pool(0); hwy::ThreadPool& pool = env.parallel.Pools().Pool(0);
fprintf(stderr, "\nBenchMatMul %lu, %lu, %lu, add=%d, MatTA=%s, MatTB=%s\n", if (env.print_config || env.print_measurement) {
M, K, N, add, TypeName<MatTA>(), TypeName<MatTB>()); fprintf(stderr, "\n");
}
fprintf(stderr, "BenchMatMul %zu, %zu, %zu, add=%d, TA=%s, TB=%s\n", M, K, N,
add, TypeName<MatTA>(), TypeName<MatTB>());
const Extents2D A_extents(M, K); const Extents2D A_extents(M, K);
const Extents2D B_extents(N, K); // already transposed const Extents2D B_extents(N, K); // already transposed
const Extents2D C_extents(M, N); const Extents2D C_extents(M, N);
RowVectorBatch<float> c_slow_batch(C_extents); RowVectorBatch<float> c_slow_batch = AllocateAlignedRows<float>(C_extents);
RowVectorBatch<float> c_batch(C_extents); RowVectorBatch<float> c_batch = AllocateAlignedRows<float>(C_extents);
std::unique_ptr<MatStorageT<float>> add_storage; std::unique_ptr<MatStorageT<float>> add_storage;
if (add) { if (add) {
@ -161,27 +165,40 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
const float* add_row = add ? add_storage->data_scale1() : nullptr; const float* add_row = add ? add_storage->data_scale1() : nullptr;
const RowPtrF C = RowPtrFromBatch(c_batch); const RowPtrF C = RowPtrFromBatch(c_batch);
constexpr size_t kSamples = 20; // Fewer reps for large batch sizes, which take longer.
const size_t num_samples = M < 32 ? 20 : 12;
std::vector<double> times; std::vector<double> times;
times.reserve(kSamples); times.reserve(num_samples);
// Ensure usage conditions are set before autotuning. Both binding and
// spinning may materially affect the choice of config. No harm in calling
// BindB/C if there is a single package: they will be a no-op.
BindB(B_extents.rows, B, env.parallel);
BindC(A_extents.rows, C, env.parallel);
Tristate use_spinning = Tristate::kDefault; Tristate use_spinning = Tristate::kDefault;
env.parallel.Pools().MaybeStartSpinning(use_spinning); env.parallel.Pools().MaybeStartSpinning(use_spinning);
// env.print_config = true;
// env.print_measurement = true;
env.print_best = true;
double keep = 0.0; double keep = 0.0;
MMPerKey* per_key;
// Until enough samples collected *after* autotuning finished: // Until enough samples collected *after* autotuning finished:
while (times.size() < kSamples) { while (times.size() < num_samples) {
const double t0 = hwy::platform::Now(); const double t0 = hwy::platform::Now();
MatMul(A, B, add_row, env, C); per_key = MatMul(A, B, add_row, env, C);
const double t1 = hwy::platform::Now(); const double t1 = hwy::platform::Now();
double elapsed = t1 - t0; double elapsed = t1 - t0;
keep += C.Row(0)[hwy::Unpredictable1()]; keep += C.Row(0)[hwy::Unpredictable1()];
times.push_back(elapsed); // Only record times after autotuning finished.
if (per_key->autotune.Best()) times.push_back(elapsed);
} }
hwy::PreventElision(keep); hwy::PreventElision(keep);
env.parallel.Pools().MaybeStopSpinning(use_spinning); env.parallel.Pools().MaybeStopSpinning(use_spinning);
PrintSpeed(A_extents, B_extents, times); PrintSpeed(A_extents, B_extents, times, per_key);
} }
using F32 = float; using F32 = float;
@ -189,22 +206,25 @@ using SFP = SfpStream;
void BenchAllMatMul() { void BenchAllMatMul() {
if (first_target == 0) first_target = HWY_TARGET; if (first_target == 0) first_target = HWY_TARGET;
if (HWY_TARGET != first_target) return; // Disable the best-target-only limitation.
// if (HWY_TARGET != first_target) return;
// Skip EMU128 (10x slower than SSE4 for SFP) and older x86.
if (HWY_TARGET == HWY_EMU128 || HWY_TARGET == HWY_SSSE3 ||
HWY_TARGET == HWY_SSE2) {
return;
}
for (size_t max_packages : {/*1,*/ 2}) {
const size_t max_threads = 0; // no limit const size_t max_threads = 0; // no limit
NestedPools pools(max_threads, Tristate::kDefault, const BoundedSlice package_slice; // all packages/sockets
BoundedSlice(0, max_packages)); const BoundedSlice cluster_slice; // all clusters/CCX
#if GEMMA_DISABLE_TOPOLOGY const BoundedSlice lp_slice; // default to all cores (per package).
if (max_packages == 2) break; // we only have one package NestedPools pools(max_threads, Tristate::kDefault, package_slice,
#else cluster_slice, lp_slice);
// If less than the limit, we have already tested all num_packages. fprintf(stderr, "BenchAllMatMul %s %s\n", pools.TopologyString(),
if (pools.Topology().FullTopology().packages.size() < max_packages) break; pools.PinString());
#endif
fprintf(stderr, "BenchAllMatMul %zu: %s %s\n", max_packages,
pools.TopologyString(), pools.PinString());
Allocator::Init(pools.Topology()); Allocator::Init(pools.Topology(), /*enable_bind=*/true);
MatMulEnv env(pools); MatMulEnv env(pools);
for (size_t batch_size : {1, 4, 128, 512}) { for (size_t batch_size : {1, 4, 128, 512}) {
@ -212,7 +232,6 @@ void BenchAllMatMul() {
BenchMatMul<BF16, SFP>(batch_size, 24576, 3072, kAdd, env); BenchMatMul<BF16, SFP>(batch_size, 24576, 3072, kAdd, env);
BenchMatMul<BF16, SFP>(batch_size, 3072, 24576, kAdd, env); BenchMatMul<BF16, SFP>(batch_size, 3072, 24576, kAdd, env);
} }
}
PROFILER_PRINT_RESULTS(); PROFILER_PRINT_RESULTS();
} }


@ -28,6 +28,7 @@
#include "compression/shared.h" #include "compression/shared.h"
#include "util/allocator.h" #include "util/allocator.h"
#include "util/app.h"
#include "util/test_util.h" #include "util/test_util.h"
#include "util/threading.h" #include "util/threading.h"
#include "hwy/base.h" #include "hwy/base.h"
@ -999,6 +1000,8 @@ struct TestShortDotsT {
const size_t N = hn::Lanes(d); const size_t N = hn::Lanes(d);
const hn::ScalableTag<float> df; // for CallDot const hn::ScalableTag<float> df; // for CallDot
NestedPools pools = CreatePools(AppArgs());
Allocator::Init(pools.Topology());
CompressWorkingSet work; CompressWorkingSet work;
std::mt19937 rng; std::mt19937 rng;
rng.seed(12345); rng.seed(12345);

File diff suppressed because it is too large.

ops/matmul.cc (new file, 415 lines)

@ -0,0 +1,415 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ops/matmul.h"
// Analytical model of cache parameters for generating autotune candidates.
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <vector>
#include "util/allocator.h"
#include "util/basics.h"
#include "util/threading.h"
#include "hwy/base.h"
#include "hwy/detect_targets.h"
#include "hwy/per_target.h"
#include "hwy/timer.h"
namespace gcpp {
namespace {
// Rounds down to a multiple of `multiple`, but returns at least `multiple`.
size_t RoundDownWithFloor(size_t value, size_t multiple) {
HWY_DASSERT(multiple != 0);
return HWY_MAX(multiple, hwy::RoundDownTo(value, multiple));
}
// Returns the highest number in `[begin, end)` that divides `dim` and is a
// multiple of `multiple`, or 0 if none exists.
size_t PrevDivisor(const size_t begin, const size_t end, const size_t dim,
const size_t multiple) {
HWY_DASSERT(end != 0 && dim != 0 && multiple != 0);
size_t prev = RoundDownWithFloor(end, multiple);
// Avoid returning `end` if rounding down had no effect.
if (prev == end) prev -= multiple;
for (;;) {
if (prev == 0) return 0; // No divisor if large multiple or small end.
if (dim % prev == 0) return prev;
if (prev <= begin) return 0;
prev -= multiple;
}
}
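// Worked example (illustrative, not in the original sources):
// PrevDivisor(/*begin=*/16, /*end=*/96, /*dim=*/256, /*multiple=*/16) starts
// at 80 (rounding 96 down would return `end` itself, so we step back one
// multiple) and then descends to 64, which divides 256 and is returned.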
// Implementation of `MMCandidates`. The class hides the `KC` etc. member functions
// and holds most of their arguments in member variables.
class GenerateCandidates {
public:
GenerateCandidates(size_t M, size_t K, size_t N, size_t max_mr, size_t nr,
const IndexRangePartition& ranges_np, bool print_config)
: M_(M),
K_(K),
max_mr_(max_mr),
nr_(nr),
// These influence kc/nc, but are also stored in `MMConfig` for
// `RangesOf*`. Must be a vector multiple. The previous/next cache line
// is likely still in L1, but we expect K > 1000 and might as well round
// up to the line size.
kc_multiple_(HWY_MIN(K, Allocator::LineBytes() / sizeof(BF16))),
nc_multiple_(Allocator::StepBytes() / sizeof(float)),
ranges_np_(ranges_np),
print_config_(print_config) {}
std::vector<MMConfig> operator()() const {
std::vector<MMConfig> candidates;
candidates.reserve(128);
for (size_t mr : MR()) {
for (MMOrder order : Orders(mr)) {
const std::vector<int>& all_inner_tasks = InnerTasks(order);
const std::vector<MMOut>& all_outs = Outs(order);
for (size_t kc : KC(mr, order)) {
for (size_t mc : MC(mr, kc, order)) {
for (size_t nc : NC(mr, mc, kc, order)) {
for (int inner_tasks : all_inner_tasks) {
for (MMOut out : all_outs) {
const MMConfig config(K_, mr, mc, kc, nc, kc_multiple_,
nc_multiple_, order, out, inner_tasks);
const size_t M_tasks = config.RangesOfMC(M_).NumTasks();
const size_t K_tasks = config.RangesOfKC(K_).NumTasks();
// Blocks only make sense when there are multiple M tasks.
if (IsBlock(order) != (M_tasks > 1)) continue;
// Single KC only makes sense when there is a single K task.
if (IsOneKC(order) != (K_tasks == 1)) continue;
candidates.push_back(config);
}
}
}
}
}
}
}
HWY_ASSERT(!candidates.empty());
return candidates;
}
private:
using SizeVec = std::vector<size_t>;
// How many rows of A per call to `MMKernel::LoopOverKC`. Lower values may
// be better for SIMD targets with fewer registers.
SizeVec MR() const {
const int64_t target = hwy::DispatchedTarget();
const bool is_avx2 = target == HWY_AVX2;
const bool is_sse = HWY_SSE4 <= target && target <= HWY_SSE2;
const bool is_wasm = target == HWY_WASM || target == HWY_WASM_EMU256;
SizeVec all_mr;
all_mr.reserve(3);
// AVX2's 16 registers are not enough for four rows, but SSE4 may benefit.
if (M_ >= max_mr_ && !is_avx2) all_mr.push_back(max_mr_);
// Allow for AVX-512 but not SSE4 (for which 4 are usually better). Also
// enable if not enough rows for 4.
if (M_ >= 2 && (M_ < max_mr_ || (!is_sse && !is_wasm))) {
all_mr.push_back(size_t{2});
}
// Even SSE4 usually prefers 2 rows; only enable for single rows.
if (M_ == 1) all_mr.push_back(size_t{1});
HWY_ASSERT(!all_mr.empty());
return all_mr;
}
// Which loop orders to enable depending on M.
std::vector<MMOrder> Orders(size_t mr) const {
std::vector<MMOrder> orders;
for (size_t order_idx = 0;; ++order_idx) {
const MMOrder order = static_cast<MMOrder>(order_idx);
if (StringFromOrder(order) == nullptr) return orders; // done
// 2D blocking is useless for a single row of M.
if (IsBlock(order) && M_ <= mr) continue;
// Conversely, N-only parallelism is uncompetitive for large M.
if (!IsBlock(order) && M_ >= 8 * mr) continue;
orders.push_back(order);
}
}
// The number of A and B columns to read between updating `partial`.
SizeVec KC(size_t mr, MMOrder order) const {
// `LoopOverKC` handles up to `mr` rows of A.
const size_t rows_a = HWY_MIN(M_, mr);
// After looping over `kc` columns, we write `mr x 4` outputs and a 16-vector
// `buf`. To amortize the write cost, we want to maximize `kc`. However, it
// is important that B fits in L1, because batch=1 only has a single row of
// A and thus no reuse of the packed B. When L1-resident, we can use the
// separate `DecompressAndZeroPad` to write `kc` columns, rather than having
// to integrate `Decompress2` into `LoopOverKC`, which is less efficient for
// TB=NUQ due to less amortization of the table loads. Due to the low L1
// latency, the packing is still effectively fused into `LoopOverKC`. It may
// be better to round up and accept a few L2 accesses in exchange for
// fewer loops over K, and thus fewer writes to `partial`. Hence we do not
// subtract the output and buf, and allow using more than the actual L1
// size. This results in an overestimate, and the loop below will propose
// the next few smaller values for the autotuner to evaluate.
const size_t bytes_ab = Allocator::L1Bytes() * 3;
const size_t col_bytes = rows_a * sizeof(BF16) + nr_ * sizeof(BF16);
size_t kc_max = hwy::DivCeil(bytes_ab, col_bytes);
kc_max =
RoundDownWithFloor(HWY_MIN(kc_max, MMStorage::kMaxKC), kc_multiple_);
kc_max = HWY_MIN(kc_max, K_);
SizeVec all_kc(1, kc_max);
// Avoid proposing kc > K.
if (K_ > kc_multiple_) {
// Generally it is best to use the full `kc` (fewer writes to `partial`),
// but a bit less can be better if it evenly divides `K`, or enables an
// `mc` that evenly divides `M`. Try several smaller values.
// If we can afford a single K task, that's usually best; only try one
// more. Otherwise, blocks may require smaller kc (more options).
const size_t reps = (kc_max == K_) ? 1 : IsBlock(order) ? 3 : 2;
size_t prev = kc_max;
for (size_t rep = 0; rep < reps; ++rep) {
const size_t div = PrevDivisor(kc_multiple_, prev, K_, kc_multiple_);
prev = div ? div : RoundDownWithFloor(prev / 2, kc_multiple_);
all_kc.push_back(prev);
}
}
if (print_config_ && all_kc.size() > 1) {
fprintf(stderr, "KC: ");
for (size_t kc : all_kc) {
fprintf(stderr, "%zu ", kc);
}
fprintf(stderr, "\n");
}
return all_kc;
}
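// Illustrative example (assumes a 32 KiB L1, 64-byte lines, mr=4, nr=4):
// col_bytes = 4*2 + 4*2 = 16 bytes, bytes_ab = 96 KiB, so kc_max = 6144,
// which is already a multiple of kc_multiple (32) and below kMaxKC; the
// first candidate is therefore kc = min(6144, K).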
// The number of (L2 resident) A rows for `A2C0` to loop over.
SizeVec MC(size_t mr, size_t kc, MMOrder order) const {
// Typically 12-24K. The B rows are pinned in L1, but also occupy L2 because
// it is typically inclusive.
const size_t bytes_b = nr_ * kc * (sizeof(SfpStream) + sizeof(BF16));
// Choose the largest feasible `mc_max` (A/C rows) to maximize reuse of the
// packed B. We want `mc * kc` elements of A to fit in L2, alongside
// `bytes_b` plus `mc` cache lines because resident-A updates `mc` rows of
// partial.
const size_t bytes_per_mc = kc * sizeof(BF16) + Allocator::LineBytes();
size_t mc_max = hwy::DivCeil(Allocator::L2Bytes() - bytes_b, bytes_per_mc);
mc_max = HWY_MIN(mc_max, MMStorage::kMaxM);
HWY_DASSERT(mc_max != 0);
mc_max = HWY_MIN(mc_max, M_);
mc_max = hwy::RoundDownTo(mc_max, mr);
SizeVec all_mc(1, mc_max);
// Larger MC is better for non-blocks, otherwise we want more small options.
const size_t reps = !IsBlock(order) ? 2 : 3;
size_t prev = mc_max;
for (size_t rep = 0; rep < reps; ++rep) {
prev = PrevDivisor(1, prev, M_, mr);
if (prev >= mc_max || prev == 0) break;
all_mc.push_back(prev);
}
// Blocks: the largest mc is not useful; drop it.
if (IsBlock(order) && all_mc.size() > 1) {
all_mc.erase(all_mc.begin(), all_mc.begin() + 1);
}
if (print_config_ && all_mc.size() > 1) {
fprintf(stderr, "MC: ");
for (size_t mc : all_mc) {
fprintf(stderr, "%zu ", mc);
}
fprintf(stderr, "\n");
}
return all_mc;
}
// The number of (possibly L3 resident) B rows per `NT_MT` task.
SizeVec NC(size_t mr, size_t mc, size_t kc, MMOrder order) const {
const size_t np_max = ranges_np_.TaskSize();
size_t nc_max = np_max;
const size_t out_bytes = IsOneKC(order) ? sizeof(float) : sizeof(double);
// Only if there will be reuse of B: choose the largest `nc_max` (C cols)
// such that `nc x kc` of B and `mc x nc` of `partial` or `C` fit in L3.
// Otherwise, leave it unbounded.
if (M_ > mr) {
const size_t bytes_per_nc = (kc * sizeof(BF16) + mc * out_bytes);
nc_max = hwy::DivCeil(Allocator::L3Bytes(), bytes_per_nc);
nc_max = HWY_MIN(HWY_MIN(nc_max, MMStorage::kMaxN), np_max);
}
HWY_DASSERT(nc_max != 0);
nc_max = RoundDownWithFloor(nc_max, nc_multiple_);
// If there are going to be multiple ranges, anything more than half would
// be imbalanced and suboptimal.
if (nc_max < np_max && nc_max >= np_max / 2) {
nc_max = RoundDownWithFloor(np_max / 2, nc_multiple_);
}
// Non-block calls ForNP, which ignores `range_nc` and uses `range_np`.
if (!IsBlock(order)) return SizeVec(1, np_max);
SizeVec all_nc(1, nc_max);
// Avoid proposing nc > N.
if (np_max > nc_multiple_) {
// Large L3, but its behavior and characteristics vary across platforms,
// hence autotune a wider range of nc than the other dimensions.
size_t reps = 10;
// For small M, we can afford larger NC, hence allow fewer small options.
if (M_ <= 2 * mr) reps -= 1;
size_t prev = nc_max;
for (size_t rep = 0; rep < reps; ++rep) {
const size_t div =
PrevDivisor(nc_multiple_, prev, np_max, nc_multiple_);
prev = div ? div : RoundDownWithFloor(prev / 2, nc_multiple_);
all_nc.push_back(prev);
if (prev == nc_multiple_) break;
}
// Skip the larger values (unlikely to be chosen), keep about 40%.
const ptrdiff_t want_delete =
static_cast<ptrdiff_t>(all_nc.size() * 5 / 9 + 2);
// Keep at least 2.
const ptrdiff_t max_delete =
HWY_MAX(static_cast<ptrdiff_t>(all_nc.size()) - 2, ptrdiff_t{0});
all_nc.erase(all_nc.begin(),
all_nc.begin() + HWY_MIN(want_delete, max_delete));
}
if (print_config_ && all_nc.size() > 1) {
fprintf(stderr, "NC: ");
for (size_t nc : all_nc) {
fprintf(stderr, "%zu ", nc);
}
fprintf(stderr, "\n");
}
return all_nc;
}
// How many tasks per cluster worker. More = smaller tasks, which can lead
// to better load balancing at the cost of higher overhead.
std::vector<int> InnerTasks(MMOrder order) const {
std::vector<int> inner_tasks;
inner_tasks.reserve(3);
inner_tasks.push_back(1);
// Blocks have one task per mc/nc range and ignore this parameter.
if (!IsBlock(order)) {
inner_tasks.push_back(2);
inner_tasks.push_back(4);
}
return inner_tasks;
}
// Whether to parallelize FillC or enable direct writes to C.
std::vector<MMOut> Outs(MMOrder order) const {
std::vector<MMOut> outs;
for (size_t out_idx = 0;; ++out_idx) {
const MMOut out = static_cast<MMOut>(out_idx);
if (StringFromOut(out) == nullptr) return outs; // done
// kParM only makes sense if we have more than one row of A.
if (out == MMOut::kParM && M_ == 1) continue;
// Blocks are already parallelized.
if (out == MMOut::kParM && IsBlock(order)) continue;
// Direct only works for a single kc range.
if ((out == MMOut::kDirect) != IsOneKC(order)) continue;
// For non-block, kCopy does not beat kDirect.
if (out == MMOut::kCopy && IsOneKC(order) && !IsBlock(order)) continue;
outs.push_back(out);
}
}
const size_t M_;
const size_t K_;
const size_t max_mr_;
const size_t nr_;
const size_t kc_multiple_;
const size_t nc_multiple_;
IndexRangePartition ranges_np_;
const bool print_config_;
};
} // namespace
// Facade to avoid exposing `GenerateCandidates` in the header.
std::vector<MMConfig> MMCandidates(size_t M, size_t K, size_t N, size_t max_mr,
size_t nr,
const IndexRangePartition& ranges_np,
bool print_config) {
return GenerateCandidates(M, K, N, max_mr, nr, ranges_np, print_config)();
}
// Returns the granularity of B rows for `RangesOfNP`. Aims to avoid remote
// memory accesses or false sharing, unless there are insufficient per-package
// rows for that.
static size_t NPMultiple(size_t N, size_t nr, size_t num_packages) {
size_t np_multiple = Allocator::QuantumBytes() / sizeof(float);
// If binding, `np_multiple` is typically 1024 and `num_packages` > 1. For
// `N` < 4096, this can cause significant load imbalance. If split unevenly,
// choose a smaller multiple.
if (N % (np_multiple * num_packages)) {
const size_t min_multiple = Allocator::LineBytes() / sizeof(float);
np_multiple =
PrevDivisor(min_multiple, np_multiple, N / num_packages, min_multiple);
if (HWY_UNLIKELY(np_multiple == 0)) {
np_multiple = min_multiple;
}
// This happens in tests with small N, hence do not assert.
if (N % (np_multiple * num_packages) && N >= 128) {
HWY_WARN("NPMultiple: N=%zu still not divisible by np_multiple=%zu\n", N,
np_multiple);
np_multiple = nr;
}
}
return np_multiple;
}
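// Illustrative example (assumes a 4 KiB binding quantum and 64-byte lines):
// for N=3072 and num_packages=2, 3072 is not a multiple of 1024*2, so we
// search for a divisor of 1536 downward from 1024 in steps of 16 and find
// np_multiple = 768, which splits N evenly across both packages.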
IndexRangePartition MMParallel::RangesOfNP(size_t max_packages, size_t N,
size_t nr) const {
const size_t num_packages = HWY_MIN(max_packages, pools_.NumPackages());
return StaticPartition(IndexRange(0, N), num_packages,
NPMultiple(N, nr, num_packages));
}
MatMulEnv::MatMulEnv(NestedPools& pools) : parallel(pools), storage(parallel) {
// Ensure Allocator::Init was called.
HWY_ASSERT(Allocator::LineBytes() != 0 && Allocator::VectorBytes() != 0);
char cpu100[100];
have_timer_stop = hwy::platform::HaveTimerStop(cpu100);
}
} // namespace gcpp


@ -16,87 +16,668 @@
#ifndef THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_H_ #ifndef THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_H_
#define THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_H_ #define THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_H_
// Non-SIMD part of MatMul: parallelization, allocation, and autotuning.
#include <stddef.h> #include <stddef.h>
#include <stdint.h>
#include <vector>
// IWYU pragma: begin_exports // IWYU pragma: begin_exports
#include "compression/compress.h" #include "compression/compress.h"
#include "util/allocator.h" #include "util/allocator.h"
#include "util/basics.h" #include "util/basics.h"
#include "util/threading.h" #include "util/threading.h"
#include "hwy/aligned_allocator.h" // Span
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/bit_set.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h"
// IWYU pragma: end_exports // IWYU pragma: end_exports
#include "hwy/per_target.h" // VectorBytes
namespace gcpp { namespace gcpp {
// TODO: remove deprecated typedef.
using Range1D = IndexRange;
// The MatMul result C[r,c] is Dot(A.Row(r), B.Col(c)). To reduce the number of // The MatMul result C[r,c] is Dot(A.Row(r), B.Col(c)). To reduce the number of
// loads, we reuse the same A row for several B columns, which are also loaded // loads, we reuse the same A row for several B columns, which are also loaded
// once for several rows of C. Thus we produce one 'tile' of C at a time of // once for several rows of C. Thus we produce one 'tile' of C at a time of
// dimensions `kRegRows` x `kRegCols`. The Reg naming is because these are // dimensions `mr (<= kMaxMR)` x `kNR`. To keep FMA units busy, this should be
// limited by the number of registers: 32 for NEON/SVE/AVX-512. `kRegCols` == 4 // at least the product of the FMA latency (3..5) times the throughput (2).
// enables the `StoreInterleaved4` transpose in `StoreHorizontalSums`. We assume // This and `mr` are limited by the number of registers, which is generally
// and verify that `C.Cols() % kRegCols == 0`. // 32 but 16 for AVX2. `kNR` == 4 enables the `StoreInterleaved4` transpose in
constexpr size_t kRegCols = 4; // `MMAddHorizontalSumsIntoPartial`. We ensure `C.Cols() % kNR == 0`.
constexpr size_t kNR = 4;
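// Illustrative register budget (not from the original sources): an mr=4 x
// kNR=4 tile requires 16 accumulators plus a few vectors for A/B loads,
// which fits the 32 registers of NEON/SVE/AVX-512 but not the 16 of AVX2;
// hence `MMCandidates` restricts AVX2 to mr <= 2 (see MR() in ops/matmul.cc).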
// Choosing `kRegRows == kRegCols` minimizes the ratio of loads to FMA, because
// we load `kRegCols + kRegRows` vectors per `kRegRows * kRegCols` element tile.
// In general, `batch_size` (A/C rows) is not a multiple of `kRegRows`. Thus
// functions that load or store a tile are parameterized on `kRowsPerTile`:
// usually `kRegRows`, but `batch_size % kRegRows` on the last row (if != 0).
constexpr size_t kRegRows = kRegCols;
struct CacheSizes {
CacheSizes() = default;
CacheSizes(const BoundedTopology::Cluster& cluster) {
// Assumes each package and cluster has the same cache sizes, and uses
// reasonable defaults if unknown.
l1_bytes = 32 * 1024; // typical size, rarely changes
l2_bytes = (cluster.PrivateKiB() ? cluster.PrivateKiB() : 256) * 1024;
l3_bytes = (cluster.SharedKiB() ? cluster.SharedKiB() : 1024) * 1024;
}
size_t l1_bytes;
size_t l2_bytes;
size_t l3_bytes;
};
class MMParallel { class MMParallel {
public: public:
MMParallel() : pools_(nullptr) {} static constexpr size_t kMaxPackages = 4;
explicit MMParallel(NestedPools& pools) : pools_(&pools) {}
NestedPools& Pools() const { return *pools_; } MMParallel(NestedPools& pools) : pools_(pools) {
hwy::ThreadPool& Pool() const { return pools_->Pool(); } HWY_DASSERT(pools_.NumPackages() <= kMaxPackages);
private:
NestedPools* pools_;
};
// Allocations and threads, shared across MatMul calls.
class MatMulEnv {
public:
explicit MatMulEnv(NestedPools& pools) : parallel(pools) {
const size_t N = hwy::VectorBytes() / sizeof(float);
buf_ = RowVectorBatch<float>(Extents2D(pools.MaxWorkers(), 16 * N));
} }
RowVectorBatch<float>& Buf() { return buf_; } // Used by tests.
NestedPools& Pools() { return pools_; }
MMParallel parallel; // Initial static partitioning of B rows across packages.
IndexRangePartition RangesOfNP(size_t max_packages, size_t N,
size_t nr) const;
// TODO: remove once no longer used. // For `BindB` and `BindC`.
NestedPools& Pools() const { return parallel.Pools(); } size_t Node(size_t pkg_idx) const {
hwy::ThreadPool& Pool() const { return parallel.Pool(); } return pools_.Topology().GetCluster(pkg_idx, 0).Node();
}
// Calls `func(pkg_idx)` for each package in parallel.
template <class Func>
void ForPkg(const size_t max_packages, const Func& func) {
pools_.AllPackages().Run(0, HWY_MIN(max_packages, pools_.NumPackages()),
[&](uint64_t task, size_t pkg_idx) {
HWY_DASSERT(task == pkg_idx);
(void)task;
func(pkg_idx);
});
}
// Cluster/CCX-aware parallel-for over B rows in `range_np`. `nx_multiple` is
// the granularity of per-cluster tasks. Calls `func(worker_range)`.
template <class Func>
void ForNP(const IndexRange& range_np, size_t nx_multiple, size_t inner_tasks,
size_t pkg_idx, const Func& func) {
HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
// Single cluster: parallel-for over static partition of `range_np`.
hwy::ThreadPool& all_clusters = pools_.AllClusters(pkg_idx);
const size_t num_clusters = all_clusters.NumWorkers();
if (num_clusters == 1) {
hwy::ThreadPool& cluster = pools_.Cluster(pkg_idx, 0);
const IndexRangePartition worker_ranges = StaticPartition(
range_np, cluster.NumWorkers() * inner_tasks, nx_multiple);
return ParallelizeOneRange(
worker_ranges, cluster,
[&](const IndexRange& worker_range, size_t /*thread*/) {
func(worker_range);
});
}
// Assign each cluster a sub-range of `range_np` (typically hundreds).
const IndexRangePartition nx_ranges =
StaticPartition(range_np, num_clusters, nx_multiple);
ParallelizeOneRange(
nx_ranges, all_clusters,
[&](const IndexRange& nx_range, const size_t cluster_idx) {
hwy::ThreadPool& cluster = pools_.Cluster(pkg_idx, cluster_idx);
// Parallel-for over sub-ranges of `nx_range` within the cluster.
const IndexRangePartition worker_ranges = StaticPartition(
nx_range, cluster.NumWorkers() * inner_tasks, nx_multiple);
ParallelizeOneRange(worker_ranges, cluster,
[&](const IndexRange& worker_range,
size_t /*thread*/) { func(worker_range); });
});
}
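// Usage sketch (illustrative only; the argument values are assumptions):
// callers typically pass one package's range from `RangesOfNP` plus a lambda
// over the resulting per-worker range:
//
//   parallel.ForNP(range_np, /*nx_multiple=*/kNR, /*inner_tasks=*/2, pkg_idx,
//                  [&](const IndexRange& worker_range) {
//                    // Process B rows [worker_range.begin(), worker_range.end()).
//                  });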
// Cluster/CCX-aware parallel-for over blocks (separate subranges of A and B
// rows). Calls `func(range_mc, range_nc)`.
template <class Func>
void ForRangesMC_NC(const IndexRangePartition& ranges_mc,
const IndexRangePartition& ranges_nc, size_t pkg_idx,
const Func& func) {
hwy::ThreadPool& all_clusters = pools_.AllClusters(pkg_idx);
// `all_clusters` is a pool with one worker per cluster in a package.
const size_t num_clusters = all_clusters.NumWorkers();
// Single (big) cluster: collapse two range indices into one parallel-for
// to reduce the number of fork-joins.
if (num_clusters == 1) {
const size_t cluster_idx = 0;
hwy::ThreadPool& cluster = pools_.Cluster(pkg_idx, cluster_idx);
// Low-batch: avoid Divide/Remainder.
if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
return ParallelizeOneRange(
ranges_nc, cluster,
[&](const IndexRange& range_nc, size_t /*thread*/) {
func(ranges_mc.Range(0), range_nc);
});
} else {
return ParallelizeTwoRanges(
ranges_mc, ranges_nc, cluster,
[&](const IndexRange& range_mc, const IndexRange& range_nc,
size_t /*thread*/) { func(range_mc, range_nc); });
}
}
// Multiple clusters: N across clusters (both are usually the larger), and
// M within each cluster. We assume auto-tuning finds small MC/NC tasks.
ParallelizeOneRange(
ranges_nc, all_clusters,
[&](const IndexRange range_nc, size_t cluster_idx) {
hwy::ThreadPool& cluster = pools_.Cluster(pkg_idx, cluster_idx);
ParallelizeOneRange(
ranges_mc, cluster,
[&](const IndexRange& range_mc, size_t /*thread*/) {
func(range_mc, range_nc);
});
});
}
// Calls `func(row_a)` in parallel.
template <class Func>
void ForRangeMC(const IndexRange& range_mc, size_t pkg_idx,
const Func& func) {
pools_.Pool(pkg_idx).Run(
range_mc.begin(), range_mc.end(),
[&](uint64_t row_a, size_t /*thread*/) { func(row_a); });
}
private: private:
RowVectorBatch<float> buf_; NestedPools& pools_;
}; };
template <typename T> // float for C, double for partial
void BindC(size_t M, const RowPtr<T>& C, MMParallel& parallel) {
if (!Allocator::ShouldBind()) return;
const IndexRangePartition ranges_np =
parallel.RangesOfNP(MMParallel::kMaxPackages, C.Cols(), kNR);
const size_t quantum = Allocator::QuantumBytes() / sizeof(T);
bool ok = true;
for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
const IndexRange& cols_c = ranges_np.Range(pkg_idx);
const size_t node = parallel.Node(pkg_idx);
for (size_t im = 0; im < M; ++im) {
// The per-package column range may not be page-aligned; bind only the
// aligned subset.
const size_t begin = hwy::RoundUpTo(cols_c.begin(), quantum);
const size_t end = hwy::RoundDownTo(cols_c.end(), quantum);
ok &= Allocator::BindMemory(C.Row(im) + begin, (end - begin) * sizeof(T),
node);
}
}
if (HWY_UNLIKELY(!ok)) {
HWY_WARN("Failed to bind C (%zux%zu), %zu packages.", M, C.Cols(),
ranges_np.NumTasks());
}
}
// Per-package storage for packed A, and one global C-shaped `partial` for
// accumulating partial dot products (sections of K).
class MMStorage {
public:
// Compile-time bounds on matrix dimensions to enable pre-allocating storage
// and reusing it across `MatMul` calls. The resulting allocations are 256 MiB
// per package and 512 MiB, respectively.
static constexpr size_t kMaxM = 2048;
static constexpr size_t kMaxK = 64 * 1024;
static constexpr size_t kMaxN = 256 * 1024;
// Upper bound for per-worker B storage on the stack. Chosen such that one row
// each of BF16 A and B fits in a 32 KiB L1, though there may be up to
// `kMaxMR` A rows and `kNR` B rows.
static constexpr size_t kMaxKC = 8 * 1024;
explicit MMStorage(MMParallel& parallel) {
// Per-package allocation so each can decompress A into its own copy.
parallel.ForPkg(MMParallel::kMaxPackages, [&](size_t pkg_idx) {
pkg_A_[pkg_idx] = AllocateAlignedRows<BF16>(Extents2D(kMaxM, kMaxK));
if (Allocator::ShouldBind()) {
const size_t node = parallel.Node(pkg_idx);
if (!Allocator::BindMemory(pkg_A_[pkg_idx].All(),
pkg_A_[pkg_idx].NumBytes(), node)) {
HWY_WARN("Failed to bind memory for package %zu", pkg_idx);
}
}
});
// Per-worker copies of `partial` would be wasteful. We instead allocate
// one instance of the maximum matrix extents because threads write at
// false-sharing-free granularity.
partial_storage_ = AllocateAlignedRows<double>(Extents2D(kMaxM, kMaxN));
// Same stride independent of the actual C.Cols() so we can pre-bind.
partial_ = RowPtrD(partial_storage_.All(), kMaxN,
StrideForCyclicOffsets<double>(kMaxN));
// Avoid cross-package accesses.
BindC(kMaxM, partial_, parallel);
}
// Returns per-package matrix view. Non-const so that `RowVectorBatch` is
// non-const, because `RowPtr` requires a non-const pointer.
RowPtrBF A(size_t pkg_idx, const Extents2D& extents) {
HWY_DASSERT(extents.rows <= kMaxM);
HWY_DASSERT(extents.cols <= kMaxK);
const size_t stride = StrideForCyclicOffsets<BF16>(extents.cols);
return RowPtrBF(pkg_A_[pkg_idx].All(), extents.cols, stride);
}
RowPtrD Partial() const { return partial_; }
private:
RowVectorBatch<BF16> pkg_A_[MMParallel::kMaxPackages];
RowVectorBatch<double> partial_storage_;
RowPtrD partial_;
};
//------------------------------------------------------------------------------
// Autotuning
// Naming convention: outer loop first, T suffix means threaded. This refers to
// the loops *around* `A2C0`, which contains loops over mc/kc. The outermost
// `ranges_np` loop across packages is implicit and applies to all of these.
//
// Parallelizing across K (A/B columns) is undesirable because the resulting
// partial dot products require synchronization or reduction across threads.
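// Reading that convention, e.g. `kNT_MT_K` below means: N blocks threaded
// (outermost), then M blocks threaded, then a sequential loop over K.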
enum class MMOrder : uint8_t {
// Single M, parallel N, sequential K (inside the parallel section to
// reduce fork-joins). Similar to GotoBLAS, good for large N vs. M and K.
kNT_K,
// Specialization of `kNT_K` for a single K task with `kDirect`.
kNT,
// Parallelize over blocks of M and N: good when both are large. We no longer
// support `kMT_NT_K`: no advantage on Skylake, and `kNT_MT_K` is 1.5x as
// fast on Zen4.
kNT_MT_K,
kNT_MT, // Specialization of `kNT_MT_K` for a single K task with `kDirect`.
// Resident C (`kK_M_NT`) should be good for large K relative to M and N.
// However, it does not (much) outperform `kNT_K` on SKX and Zen4. There are
// no kN* because we expect M (batch size) to be small relative to K and N.
};
static inline bool IsBlock(MMOrder order) {
return order == MMOrder::kNT_MT_K || order == MMOrder::kNT_MT;
}
static inline bool IsOneKC(MMOrder order) {
return order == MMOrder::kNT || order == MMOrder::kNT_MT;
}
static inline const char* StringFromOrder(MMOrder order) {
switch (order) {
case MMOrder::kNT_K:
return "NT_K";
case MMOrder::kNT:
return "NT";
case MMOrder::kNT_MT_K:
return "NT_MT_K";
case MMOrder::kNT_MT:
return "NT_MT";
default:
return nullptr;
}
}
// How/where to write the A2C0 result. This determines the `tag` argument to
// that function, which governs whether we call `MMStoreHorizontalSumsIntoC` or
// `MMAddHorizontalSumsIntoPartial`.
enum class MMOut : uint8_t {
kCopy, // accumulate into partial, scale/add to C
kDirect, // single kc task, write directly to C
kParM // kCopy but parallel over M
// kParN is not better on SKX/Zen4.
};
static inline const char* StringFromOut(MMOut out) {
switch (out) {
case MMOut::kDirect:
return "Direct";
case MMOut::kCopy:
return "Copy";
case MMOut::kParM:
return "ParM";
default:
return nullptr;
}
}
// How to parallelize the per-package `DecompressA`. To reduce combinatorial
// explosion, we tune this separately from `MMConfig`.
enum class MMParA : uint8_t { kNone, kK1 = 1, kK2 = 2, kK4 = 4, kM };
static inline const char* StringFromParA(MMParA par_a) {
switch (par_a) {
case MMParA::kNone:
return "ParA0 ";
case MMParA::kK1:
return "ParAK1";
case MMParA::kK2:
return "ParAK2";
case MMParA::kK4:
return "ParAK4";
case MMParA::kM:
return "ParAM ";
default:
return nullptr;
}
}
// Possible configurations for the autotuner to choose from:
// `mr` := C rows to write at a time (< #registers / `kNR`),
// `kc` := A / B columns such that `mr` rows fit in L1,
// `mc` := A rows such that `kc` columns fit in L2,
// `nc` := B rows such that `kc` columns fit in L3 alongside `mc x nc` C.
// Also includes loop order and task granularity.
#pragma pack(push, 1)
class MMConfig {
public:
MMConfig() = default; // for std::vector
// `mr` is the number of A rows per call to `MMKernel::LoopOverKC`.
// `MMOrder` is how to parallelize the outer loops.
// `MMOut` is how/whether to parallelize filling the C result.
// `inner_tasks` chooses the within-cluster task granularity in `ForNP`.
MMConfig(size_t K, size_t mr, size_t mc, size_t kc, size_t nc,
size_t kc_multiple, size_t nc_multiple, MMOrder order, MMOut out,
int inner_tasks)
: mr_(static_cast<uint32_t>(mr)),
mc_(static_cast<uint32_t>(mc)),
kc_(static_cast<uint32_t>(kc)),
nc_(static_cast<uint32_t>(nc)),
nc_multiple_(static_cast<uint32_t>(nc_multiple)),
kc_multiple_(static_cast<uint32_t>(kc_multiple)),
order_(order),
out_(out),
inner_tasks_(static_cast<uint8_t>(inner_tasks)),
reserved_{} {
HWY_DASSERT(mr == 1 || mr == 2 || mr == 4);
if (mc % mr != 0) {
HWY_WARN("mc %zu not a multiple of mr %zu", mc, mr);
}
// Do not warn for single-kc tasks; some models unfortunately have K which
// are not multiples of `kc_multiple`.
if (kc != K && (kc % kc_multiple) != 0) {
HWY_WARN("kc %zu not a multiple of kc_multiple %zu", kc, kc_multiple);
}
if (nc % nc_multiple != 0) {
HWY_WARN("nc %zu not a multiple of nc_multiple %zu", nc, nc_multiple);
}
HWY_DASSERT(StringFromOrder(order_) != nullptr);
HWY_DASSERT(StringFromOut(out_) != nullptr);
HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
}
// Splits M/N into blocks which are visited sequentially or in parallel.
// K is always sequential, see `MMOrder`.
IndexRangePartition RangesOfMC(size_t M) const {
return MaxSizePartition(IndexRange(0, M), mc_, mr_);
}
IndexRangePartition RangesOfKC(size_t K) const {
return MaxSizePartition(IndexRange(0, K), kc_, kc_multiple_);
}
IndexRangePartition RangesOfNC(IndexRange range_np) const {
return MaxSizePartition(range_np, nc_, nc_multiple_);
}
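// Illustrative (assumed) example: with mc_=128 and mr_=4, RangesOfMC(512)
// presumably yields four ranges of 128 rows, because MaxSizePartition splits
// [0, M) into chunks of at most mc_ aligned to multiples of mr_.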
MMOrder Order() const { return order_; }
MMOut Out() const { return out_; }
// No `OuterTasks` because static partitioning across clusters is sufficient.
size_t InnerTasks() const { return static_cast<size_t>(inner_tasks_); }
// Accessors for printing autotune result.
size_t MR() const { return static_cast<size_t>(mr_); }
size_t MC() const { return static_cast<size_t>(mc_); }
size_t KC() const { return static_cast<size_t>(kc_); }
size_t NC() const { return static_cast<size_t>(nc_); }
private:
// Somewhat-compressed representation because MMCandidates may return dozens.
uint32_t mr_;
uint32_t mc_;
uint32_t kc_;
uint32_t nc_;
uint32_t nc_multiple_;
uint32_t kc_multiple_;
MMOrder order_;
MMOut out_;
uint8_t inner_tasks_;
HWY_MAYBE_UNUSED uint8_t reserved_[5];
};
static_assert(sizeof(MMConfig) == 32); // for faster indexing
#pragma pack(pop)
std::vector<MMConfig> MMCandidates(size_t M, size_t K, size_t N, size_t max_mr,
size_t nr,
const IndexRangePartition& ranges_np,
bool print_config);
// State machine for choosing the best `TConfig`, which is `MMConfig` for the
// main MatMul autotuner.
template <typename TConfig>
class MMAutoTune {
public:
// Returns nullptr if not yet finished, otherwise the best config. Do not
// store this pointer because it can be invalidated.
const TConfig* Best() const { return best_; }
// If false, caller must call `SetCandidates` before `NextConfig`.
bool HasCandidates() const {
HWY_DASSERT(!Best());
return !candidates_.empty();
}
void SetCandidates(std::vector<TConfig> candidates) {
HWY_DASSERT(!HasCandidates());
candidates_.swap(candidates);
HWY_DASSERT(HasCandidates());
min_ticks_.resize(candidates_.size(), ~uint64_t{0});
}
// Returns the current `TConfig` to measure.
const TConfig& NextConfig() const {
HWY_DASSERT(!Best() && HasCandidates());
return candidates_[config_idx_];
}
// Returns the best ticks so far for this candidate. Negligible CPU time.
uint64_t NotifyTicks(uint64_t ticks) {
HWY_DASSERT(HasCandidates());
HWY_DASSERT(!skipped_.Get(config_idx_));
best_ticks_ = HWY_MIN(best_ticks_, ticks);
min_ticks_[config_idx_] = HWY_MIN(min_ticks_[config_idx_], ticks);
// Best so far. Save because we update `config_idx_` below.
const size_t my_best_ticks = min_ticks_[config_idx_];
const size_t my_idx = config_idx_;
// Advance/wrap around to next non-skipped config. Do this first because it
// updates `rounds_complete_`. To decorrelate measurements, we do not
// immediately re-measure the same config.
for (;;) {
++config_idx_;
if (HWY_UNLIKELY(config_idx_ == candidates_.size())) {
config_idx_ = 0;
++rounds_complete_;
}
// Guaranteed to terminate because the candidate that achieved `best_ticks_`
// is never worse than any other, hence never skipped.
if (!skipped_.Get(config_idx_)) break;
}
// Disqualify from future `NextConfig` if the best of two measurements so
// far is sufficiently worse than `best_ticks_`. This tolerates some noise
// in the first or second measurement.
if (rounds_complete_ != 0 && my_best_ticks > 5 * best_ticks_ / 4) {
skipped_.Set(my_idx);
}
// After sufficient rounds, choose the winner.
if (rounds_complete_ == 4) {
for (size_t i = 0; i < candidates_.size(); ++i) {
worst_min_ticks_ = HWY_MAX(worst_min_ticks_, min_ticks_[i]);
if (min_ticks_[i] == best_ticks_) {
// Causes `Best()` to be non-null, hence `MatMul` will no longer call
// `NextConfig` for this shape.
best_ = &candidates_[i];
config_idx_ = i; // just in case callers want to know which index.
}
}
HWY_DASSERT(best_ != nullptr); // no min_ticks_ matches best_ticks_
}
return my_best_ticks;
}
// Avoid printing the first two rounds, because those might be noisy and not
// yet skipped.
bool ShouldPrint() { return rounds_complete_ > 2; }
// Only valid after Best() is non-null. Used to compute the autotuning gain.
uint64_t BestTicks() const { return best_ticks_; }
uint64_t WorstMinTicks() const { return worst_min_ticks_; }
uint64_t FirstConfigTicks() const { return min_ticks_[0]; }
private:
const TConfig* best_ = nullptr;
std::vector<TConfig> candidates_;
// Use Min because threads are pinned, so we only expect additive noise.
std::vector<uint64_t> min_ticks_; // one per candidate
size_t config_idx_ = 0; // [0, candidates_.size())
size_t rounds_complete_ = 0;
uint64_t best_ticks_ = ~uint64_t{0};
uint64_t worst_min_ticks_ = 0;
hwy::BitSet4096<> skipped_;
};
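// Usage sketch (illustrative, simplified; `MeasureTicks` is a hypothetical
// stand-in for timing one MatMul call with `cfg`):
//
//   MMAutoTune<MMConfig>& tuner = per_key.autotune;
//   if (!tuner.Best()) {
//     if (!tuner.HasCandidates()) {
//       tuner.SetCandidates(MMCandidates(M, K, N, max_mr, kNR,
//                                        per_key.ranges_np,
//                                        env.print_config));
//     }
//     const MMConfig& cfg = tuner.NextConfig();
//     tuner.NotifyTicks(MeasureTicks(cfg));
//   }  // Once Best() is non-null, run with *tuner.Best() and stop measuring.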
//------------------------------------------------------------------------------
// Map of previously seen dimensions to index via linear search.
class MMKeys {
public:
using Key = uint64_t;
// KeyFromDims will only return this if all dims are zero, which is invalid.
static constexpr Key kPadding = 0;
// Compresses the dimensions into a single Key for faster comparison.
static Key KeyFromDims(size_t M, size_t K, size_t N) {
HWY_DASSERT(M < (Key{1} << 16)); // batch sizes are smaller
HWY_DASSERT(K < (Key{1} << 24));
HWY_DASSERT(N < (Key{1} << 24));
const Key key = static_cast<Key>(M) | (static_cast<Key>(K) << 16) |
(static_cast<Key>(N) << 40);
HWY_DASSERT(key != kPadding);
return key;
}
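// Illustrative layout: M occupies bits [0, 16), K bits [16, 40) and N bits
// [40, 64); e.g. M=512, K=3072, N=24576 packs to
// 512 | (3072 << 16) | (24576 << 40).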
// We leave the search to callers so they can use dynamic-dispatched SIMD,
// which is not possible in this header.
hwy::Span<const Key> Keys() const {
return hwy::Span<const Key>(keys_.get(), num_unique_);
}
// Must only be called if not already present in `Keys()`.
void Append(Key key) {
// Dynamic allocation because the test checks many more dimensions than
// would be reasonable to pre-allocate. DIY for alignment and padding.
if (HWY_UNLIKELY(num_unique_ >= capacity_)) {
const size_t NU64 = Allocator::VectorBytes() / sizeof(Key);
// Start at one vector so the size is always a multiple of N.
if (HWY_UNLIKELY(capacity_ == 0)) {
capacity_ = hwy::DivCeil(NU64, 2); // will be doubled below
}
capacity_ *= 2;
HWY_DASSERT(capacity_ >= num_unique_ + 1);
hwy::AlignedFreeUniquePtr<Key[]> new_keys =
hwy::AllocateAligned<Key>(capacity_);
hwy::CopyBytes(keys_.get(), new_keys.get(), num_unique_ * sizeof(Key));
// Pad for SIMD.
for (size_t i = num_unique_; i < hwy::RoundUpTo(num_unique_, NU64); ++i) {
new_keys[i] = kPadding;
}
keys_.swap(new_keys);
}
keys_[num_unique_++] = key;
}
private:
size_t capacity_ = 0;
size_t num_unique_ = 0;
hwy::AlignedFreeUniquePtr<Key[]> keys_;
};
// Per-MatMul-shape state.
struct MMPerKey {
MMPerKey(size_t max_packages, size_t N, size_t nr, MMParallel& parallel)
: ranges_np(parallel.RangesOfNP(max_packages, N, nr)) {}
// Only profile if enabled and the main autotuner finished (the par_a
// autotuner is per-package and we want to avoid synchronization).
bool WantProfile() const { return PROFILER_ENABLED != 0 && autotune.Best(); }
const IndexRangePartition ranges_np;
MMAutoTune<MMConfig> autotune;
MMAutoTune<MMParA> autotune_par_a[MMParallel::kMaxPackages];
};
// Stores state shared across MatMul calls. Non-copyable.
struct MatMulEnv {
explicit MatMulEnv(NestedPools& pools);
bool have_timer_stop = false;
// Enable binding: disabled in Gemma until tensors support it, enabled in
// bench_matmul.cc.
bool enable_bind = false;
// Whether `MMCandidates()` should print the set of parameters.
bool print_config = false;
// Whether to print each config's speed during autotuning.
bool print_measurement = false;
// Whether to print the best config immediately after autotuning finished.
bool print_best = false;
MMParallel parallel;
MMStorage storage;
MMKeys keys;
std::vector<MMPerKey> per_key;
};
// Arguments to MatMul() that are independent of the A/B type.
// Reduces register pressure compared to individual values/references.
struct MMArgs {
MMArgs(MatMulEnv& env, MMPerKey& per_key, double scale,
const float* HWY_RESTRICT add, const RowPtrD& partial,
const RowPtrF& C)
: env(&env),
per_key(&per_key),
scale(scale),
add(add),
partial(partial),
C(C) {}
MatMulEnv* env;
MMPerKey* per_key;
double scale;
const float* HWY_RESTRICT add;
// Same size as C, threads write at false-sharing-free granularity.
RowPtrD partial;
RowPtrF C;
};
// Wrapper over hwy::Zone that is only enabled when autotuning finished.
#if PROFILER_ENABLED
class MMZone {
using Zone = hwy::Zone;
static_assert(alignof(Zone) <= 8 && sizeof(Zone) <= 8);
public:
~MMZone() {
if (used_) {
Zone* zone = reinterpret_cast<Zone*>(&data_);
zone->~Zone();
}
}
// `name` must be a string literal.
void MaybeEnter(const char* name, const MMArgs& args) {
if (args.per_key->WantProfile()) {
new (&data_) Zone(name);
used_ = true;
}
}
private:
uint64_t data_ = 0;
bool used_ = false;
};
#else
struct MMZone {
void MaybeEnter(const char*, const MMArgs&) {}
};
#endif // PROFILER_ENABLED
// Used for the A and B arguments of `MatMul`, which are always const. // Used for the A and B arguments of `MatMul`, which are always const.
// Create via MakeConstMat. This differs from `RowPtr` in that it supports the // Create via MakeConstMat. This differs from `RowPtr` in that it supports the
// `ofs` required for compressed T. // `ofs` required for compressed T.
@ -161,6 +742,29 @@ ConstMat<T> ConstMatFromWeights(const MatPtrT<T>& m, size_t ofs = 0) {
return mat; return mat;
} }
template <typename TB>
void BindB(size_t N, const ConstMat<TB>& B, MMParallel& parallel) {
if (!Allocator::ShouldBind()) return;
const IndexRangePartition ranges_np =
parallel.RangesOfNP(MMParallel::kMaxPackages, N, kNR);
const size_t quantum = Allocator::QuantumBytes() / sizeof(TB);
for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
const IndexRange& rows_b = ranges_np.Range(pkg_idx);
const size_t node = parallel.Node(pkg_idx);
uintptr_t begin =
reinterpret_cast<uintptr_t>(B.ptr + B.Row(rows_b.begin()));
uintptr_t end = begin + rows_b.Num() * B.Stride() * sizeof(TB);
// B is not yet guaranteed to have padded rows, so only bind the
// subset that is page-aligned.
begin = hwy::RoundUpTo(begin, quantum);
end = hwy::RoundDownTo(end, quantum);
if (HWY_LIKELY(begin != end)) {
Allocator::BindMemory(reinterpret_cast<void*>(begin), end - begin, node);
}
}
}
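To make the page-alignment trimming above concrete, a small standalone example with assumed numbers (4 KiB quantum, arbitrary unaligned addresses); it is not part of the library:

#include <cstdint>
#include <cstdio>
#include "hwy/base.h"  // hwy::RoundUpTo, hwy::RoundDownTo

int main() {
  const uintptr_t quantum = 4096;    // assumed Allocator::QuantumBytes()
  uintptr_t begin = 0x7f0000001100;  // start of this package's rows of B
  uintptr_t end = 0x7f0000081f80;    // one past the end
  begin = hwy::RoundUpTo(begin, quantum);  // -> 0x7f0000002000
  end = hwy::RoundDownTo(end, quantum);    // -> 0x7f0000081000
  // Only the 508 KiB in between would be bound; the unaligned head and tail
  // stay wherever the allocator placed them.
  printf("bind %zu KiB\n", static_cast<size_t>(end - begin) / 1024);
  return 0;
}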
} // namespace gcpp } // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_H_ #endif // THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_H_

View File

@ -243,17 +243,20 @@ template <typename TA, typename TB = TA>
void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add, void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
MatMulEnv& env) { MatMulEnv& env) {
hwy::ThreadPool& pool = env.parallel.Pools().Pool(); hwy::ThreadPool& pool = env.parallel.Pools().Pool();
fprintf(stderr, "TestMatMul %zu, %zu, %zu, add=%d, TA=%s, TB=%s\n", rows_ac, fprintf(stderr, "TestMatMul %zu, K=%zu, %zu, add=%d, TA=%s, TB=%s\n", rows_ac,
cols_a_rows_b, cols_bc, add, TypeName<TA>(), TypeName<TB>()); cols_a_rows_b, cols_bc, add, TypeName<TA>(), TypeName<TB>());
env.print_config = true;
env.print_best = true;
const Extents2D A_extents(rows_ac, cols_a_rows_b); const Extents2D A_extents(rows_ac, cols_a_rows_b);
const Extents2D B_extents(cols_bc, cols_a_rows_b); // already transposed const Extents2D B_extents(cols_bc, cols_a_rows_b); // already transposed
const Extents2D C_extents(rows_ac, cols_bc); const Extents2D C_extents(rows_ac, cols_bc);
MatStoragePtr<TA> a = GenerateMat<TA>(A_extents, pool); MatStoragePtr<TA> a = GenerateMat<TA>(A_extents, pool);
MatStoragePtr<TB> b_trans = GenerateTransposedMat<TB>(B_extents, pool); MatStoragePtr<TB> b_trans = GenerateTransposedMat<TB>(B_extents, pool);
RowVectorBatch<float> c_slow_batch(C_extents); RowVectorBatch<float> c_slow_batch = AllocateAlignedRows<float>(C_extents);
RowVectorBatch<float> c_batch(C_extents); RowVectorBatch<float> c_batch = AllocateAlignedRows<float>(C_extents);
HWY_ASSERT(a && b_trans); HWY_ASSERT(a && b_trans);
std::unique_ptr<MatStorageT<float>> add_storage; std::unique_ptr<MatStorageT<float>> add_storage;
@ -270,8 +273,12 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
const RowPtrF C = RowPtrFromBatch(c_batch); const RowPtrF C = RowPtrFromBatch(c_batch);
MatMulSlow(A, B, add_row, env, C_slow); MatMulSlow(A, B, add_row, env, C_slow);
MatMul(A, B, add_row, env, C); // A few reps to get coverage of the various autotuned code paths.
for (size_t rep = 0; rep < 16; ++rep) {
MMPerKey* per_key = MatMul(A, B, add_row, env, C);
AssertClose(A, B, C_slow, C); AssertClose(A, B, C_slow, C);
if (per_key->autotune.Best()) break;
}
} }
using F32 = float; using F32 = float;
@ -298,13 +305,12 @@ void TestTiny() {
Tristate use_spinning = Tristate::kDefault; Tristate use_spinning = Tristate::kDefault;
pools.MaybeStartSpinning(use_spinning); pools.MaybeStartSpinning(use_spinning);
Allocator::Init(pools.Topology()); Allocator::Init(pools.Topology(), /*enable_bind=*/true);
MatMulEnv env(pools); MatMulEnv env(pools);
for (size_t M = 1; M <= 3 * kRegRows; ++M) { for (size_t M = 1; M <= 12; ++M) {
for (size_t K = 64; K <= 128; K *= 2) { for (size_t K = 1; K <= 64; K *= 2) {
for (size_t N = /*kRegRows*/ 16; N <= 64; for (size_t N = 4; N <= 64; N += max_packages * 4) {
N += max_packages * kRegRows) {
TestMatMul<F32, F32>(M, K, N, /*add=*/false, env); TestMatMul<F32, F32>(M, K, N, /*add=*/false, env);
} }
} }
@ -323,7 +329,7 @@ void TestAllMatMul() {
NestedPools pools(0); // no limits NestedPools pools(0); // no limits
Tristate use_spinning = Tristate::kDefault; Tristate use_spinning = Tristate::kDefault;
pools.MaybeStartSpinning(use_spinning); pools.MaybeStartSpinning(use_spinning);
Allocator::Init(pools.Topology()); Allocator::Init(pools.Topology(), /*enable_bind=*/true);
MatMulEnv env(pools); MatMulEnv env(pools);
// Sizes seen in gemma_test 2B. // Sizes seen in gemma_test 2B.

View File

@ -1,17 +0,0 @@
// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// TODO: Tests of individual MatMul components.
int main() { return 0; }

View File

@ -35,6 +35,7 @@
#include "gemma/common.h" #include "gemma/common.h"
#include "gemma/configs.h" #include "gemma/configs.h"
#include "util/allocator.h" #include "util/allocator.h"
#include "util/app.h"
#include "util/test_util.h" #include "util/test_util.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/tests/hwy_gtest.h" #include "hwy/tests/hwy_gtest.h"
@ -386,6 +387,9 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
} }
void TestRopeAndMulBy() { void TestRopeAndMulBy() {
NestedPools pools = CreatePools(AppArgs());
Allocator::Init(pools.Topology());
ModelConfig config = ConfigFromModel(Model::GEMMA2_9B); ModelConfig config = ConfigFromModel(Model::GEMMA2_9B);
int dim_qkv = config.layer_configs[0].qkv_dim; int dim_qkv = config.layer_configs[0].qkv_dim;
RowVectorBatch<float> x(Extents2D(1, dim_qkv)); RowVectorBatch<float> x(Extents2D(1, dim_qkv));

View File

@ -34,6 +34,11 @@
#endif #endif
#ifndef GEMMA_BIND // allow override #ifndef GEMMA_BIND // allow override
// OSes will generally do the right thing when threads allocate their own
// working memory. However, matmul's B and C matrices are preferably sharded
// across NUMA nodes. To simplify the matrix representation, we prefer a
// single allocation. This requires page-level control over the memory layout,
// which Linux provides via `move_pages`, but Windows does not.
#if defined(GEMMA_LINUX_SYSCALL6) && !defined(__ANDROID_API__) #if defined(GEMMA_LINUX_SYSCALL6) && !defined(__ANDROID_API__)
#define GEMMA_BIND 1 #define GEMMA_BIND 1
#else #else
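The Linux mechanism mentioned in the comment can be sketched as follows. This is a standalone illustration of `move_pages`, not the code path used here (that lives behind `Allocator::BindMemory`); the function name and error handling are illustrative.

// Sketch: bind the pages of [ptr, ptr + bytes) to NUMA node `node`.
// Linux only; link with -lnuma for <numaif.h>. Assumes ptr is page-aligned
// and bytes is a multiple of the page size.
#include <numaif.h>   // move_pages, MPOL_MF_MOVE
#include <unistd.h>   // sysconf
#include <cstddef>
#include <cstdio>
#include <vector>

bool BindToNode(void* ptr, size_t bytes, int node) {
  const size_t page = static_cast<size_t>(sysconf(_SC_PAGESIZE));
  const size_t num_pages = bytes / page;
  std::vector<void*> pages(num_pages);
  std::vector<int> nodes(num_pages, node);
  std::vector<int> status(num_pages);
  for (size_t i = 0; i < num_pages; ++i) {
    pages[i] = static_cast<char*>(ptr) + i * page;
  }
  // pid 0 = calling process; MPOL_MF_MOVE migrates pages we already touched.
  const long ret = move_pages(0, num_pages, pages.data(), nodes.data(),
                              status.data(), MPOL_MF_MOVE);
  if (ret < 0) perror("move_pages");
  return ret == 0;
}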
@ -93,7 +98,7 @@ size_t Allocator::L2Bytes() { return l2_bytes_; }
size_t Allocator::L3Bytes() { return l3_bytes_; } size_t Allocator::L3Bytes() { return l3_bytes_; }
bool Allocator::ShouldBind() { return should_bind_; } bool Allocator::ShouldBind() { return should_bind_; }
void Allocator::Init(const BoundedTopology& topology) { void Allocator::Init(const BoundedTopology& topology, bool enable_bind) {
line_bytes_ = DetectLineBytes(); line_bytes_ = DetectLineBytes();
vector_bytes_ = hwy::VectorBytes(); vector_bytes_ = hwy::VectorBytes();
step_bytes_ = HWY_MAX(line_bytes_, vector_bytes_); step_bytes_ = HWY_MAX(line_bytes_, vector_bytes_);
@ -122,6 +127,7 @@ void Allocator::Init(const BoundedTopology& topology) {
const size_t page_bytes = DetectPageSize(); const size_t page_bytes = DetectPageSize();
if ((page_bytes != 0 && page_bytes <= 16 * 1024) && if ((page_bytes != 0 && page_bytes <= 16 * 1024) &&
topology.NumNodes() > 1 && topology.NumPackages() > 1) { topology.NumNodes() > 1 && topology.NumPackages() > 1) {
if (enable_bind) {
// Ensure pages meet the alignment requirements of `AllocBytes`. // Ensure pages meet the alignment requirements of `AllocBytes`.
HWY_ASSERT(page_bytes >= quantum_bytes_); HWY_ASSERT(page_bytes >= quantum_bytes_);
quantum_bytes_ = page_bytes; quantum_bytes_ = page_bytes;
@ -129,6 +135,11 @@ void Allocator::Init(const BoundedTopology& topology) {
HWY_ASSERT(MaxQuantumBytes() >= quantum_bytes_); HWY_ASSERT(MaxQuantumBytes() >= quantum_bytes_);
quantum_bytes_ = HWY_MIN(quantum_bytes_, MaxQuantumBytes()); quantum_bytes_ = HWY_MIN(quantum_bytes_, MaxQuantumBytes());
should_bind_ = true; should_bind_ = true;
} else {
HWY_WARN(
"Multiple sockets but binding disabled. This reduces speed; "
"set or remove enable_bind to avoid this warning.");
}
} }
} }
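In other words, callers opt in per binary. A short sketch of the two call patterns appearing in this change:

// Opting in, as matmul_test.cc above does: on a multi-socket machine with
// small pages, Init() then uses page-sized quanta and sets should_bind_.
Allocator::Init(pools.Topology(), /*enable_bind=*/true);

// Default, as ops_test.cc above does: binding stays off, and on a machine
// with multiple packages and NUMA nodes Init() emits the warning instead.
Allocator::Init(pools.Topology());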

View File

@ -16,6 +16,8 @@
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_ALLOCATOR_H_ #ifndef THIRD_PARTY_GEMMA_CPP_UTIL_ALLOCATOR_H_
#define THIRD_PARTY_GEMMA_CPP_UTIL_ALLOCATOR_H_ #define THIRD_PARTY_GEMMA_CPP_UTIL_ALLOCATOR_H_
// Allocator with support for sharding tensors across NUMA nodes.
#include <stddef.h> #include <stddef.h>
#include <stdint.h> #include <stdint.h>
@ -65,7 +67,8 @@ class Allocator {
public: public:
// Must be called at least once before any other function. Not thread-safe, // Must be called at least once before any other function. Not thread-safe,
// hence only call this from the main thread. // hence only call this from the main thread.
static void Init(const BoundedTopology& topology); // TODO: remove enable_bind once Gemma tensors support binding.
static void Init(const BoundedTopology& topology, bool enable_bind = false);
// Bytes per cache line, or a reasonable guess if unknown. Used to choose // Bytes per cache line, or a reasonable guess if unknown. Used to choose
// ranges such that there will be no false sharing. // ranges such that there will be no false sharing.
@ -80,8 +83,10 @@ class Allocator {
static constexpr size_t MaxQuantumBytes() { return 4096; } static constexpr size_t MaxQuantumBytes() { return 4096; }
static size_t QuantumSteps(); // = QuantumBytes() / StepBytes() static size_t QuantumSteps(); // = QuantumBytes() / StepBytes()
// L1 and L2 are typically per core.
static size_t L1Bytes(); static size_t L1Bytes();
static size_t L2Bytes(); static size_t L2Bytes();
// Clusters often share an L3. We return the total size per package.
static size_t L3Bytes(); static size_t L3Bytes();
// Returns pointer aligned to `QuantumBytes()`. // Returns pointer aligned to `QuantumBytes()`.
@ -119,6 +124,19 @@ class Allocator {
static PtrAndDeleter AllocBytes(size_t bytes); static PtrAndDeleter AllocBytes(size_t bytes);
}; };
// Value of `stride` to pass to `RowVectorBatch` to enable the "cyclic offsets"
// optimization. If `Allocator::ShouldBind()`, `Allocator::QuantumBytes()` is
// typically 4KiB. To avoid remote accesses, we would thus pad each row to that,
// which results in 4K aliasing and/or cache conflict misses. `RowPtr` is able
// to prevent that by pulling rows forward by a cyclic offset, which is still a
// multiple of the cache line size. This requires an additional
// `Allocator::QuantumBytes()` of padding after also rounding up to that.
template <typename T>
constexpr size_t StrideForCyclicOffsets(size_t cols) {
const size_t quantum = Allocator::MaxQuantumBytes() / sizeof(T);
return hwy::RoundUpTo(cols, quantum) + quantum;
}
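A worked instance of the formula, assuming Allocator::MaxQuantumBytes() == 4096 as declared above; the helper below is illustrative, not part of the API:

// T = float: quantum = 4096 / sizeof(float) = 1024 elements.
//   cols = 2048 -> RoundUpTo(2048, 1024) + 1024 = 3072 floats (12 KiB per row)
//   cols = 2050 -> RoundUpTo(2050, 1024) + 1024 = 4096 floats (16 KiB per row)
// The extra quantum is the slack RowPtr spends on per-row cyclic offsets, so
// consecutive rows no longer map to the same cache sets (no 4K aliasing).
inline void CheckStrideForCyclicOffsetsExamples() {
  HWY_ASSERT(StrideForCyclicOffsets<float>(2048) == 3072);
  HWY_ASSERT(StrideForCyclicOffsets<float>(2050) == 4096);
}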
// Owns dynamically-allocated aligned memory for a batch of row vectors. // Owns dynamically-allocated aligned memory for a batch of row vectors.
// This can be seen as a (batch_size x cols) matrix. Unlike `RowPtr`, this owns // This can be seen as a (batch_size x cols) matrix. Unlike `RowPtr`, this owns
// the memory. // the memory.
@ -130,6 +148,7 @@ class RowVectorBatch {
// Main ctor, called from Activations::Allocate. If `stride` = 0, the default, // Main ctor, called from Activations::Allocate. If `stride` = 0, the default,
// we default to tightly packed rows (`stride = cols`). // we default to tightly packed rows (`stride = cols`).
// WARNING: not all call sites support `stride` != cols. // WARNING: not all call sites support `stride` != cols.
// TODO: once they do, remove stride and behave like AllocateAlignedRows here.
RowVectorBatch(Extents2D extents, size_t stride = 0) : extents_(extents) { RowVectorBatch(Extents2D extents, size_t stride = 0) : extents_(extents) {
if (stride == 0) { if (stride == 0) {
stride_ = extents_.cols; stride_ = extents_.cols;
@ -137,7 +156,10 @@ class RowVectorBatch {
HWY_ASSERT(stride >= extents_.cols); HWY_ASSERT(stride >= extents_.cols);
stride_ = stride; stride_ = stride;
} }
mem_ = Allocator::Alloc<T>(extents_.rows * stride_); // Allow binding the entire matrix.
const size_t padded = hwy::RoundUpTo(extents_.rows * stride_,
Allocator::QuantumBytes() / sizeof(T));
mem_ = Allocator::Alloc<T>(padded);
} }
// Move-only // Move-only
@ -186,6 +208,11 @@ static HWY_INLINE size_t RoundUpToOddLines(size_t num, size_t line_bytes) {
return padded_num; return padded_num;
} }
template <typename T>
RowVectorBatch<T> AllocateAlignedRows(Extents2D extents) {
return RowVectorBatch<T>(extents, StrideForCyclicOffsets<T>(extents.cols));
}
// Lightweight version of `MatPtr` used for the C argument of `MatMul`, because // Lightweight version of `MatPtr` used for the C argument of `MatMul`, because
// it is always float and does not support compressed T, but does support an // it is always float and does not support compressed T, but does support an
// arbitrary stride >= cols. // arbitrary stride >= cols.
@ -202,7 +229,19 @@ class RowPtr {
row_mask_(Allocator::QuantumSteps() - 1) { row_mask_(Allocator::QuantumSteps() - 1) {
HWY_DASSERT(stride >= cols); HWY_DASSERT(stride >= cols);
HWY_DASSERT(row_mask_ != ~size_t{0}); HWY_DASSERT(row_mask_ != ~size_t{0});
row_mask_ = 0; // TODO: remove if constexpr (HWY_IS_DEBUG_BUILD) {
if (stride < StrideForCyclicOffsets<T>(cols)) {
static bool once;
if (!once) {
once = true;
HWY_WARN(
"Check why RowPtr stride=%zu < StrideForCyclicOffsets(cols=%zu), "
"T=%zu; this forces us to disable cyclic offsets.",
stride, cols, sizeof(T));
}
row_mask_ = 0;
}
}
} }
RowPtr(T* HWY_RESTRICT row0, size_t cols) : RowPtr(row0, cols, cols) {} RowPtr(T* HWY_RESTRICT row0, size_t cols) : RowPtr(row0, cols, cols) {}

View File

@ -295,6 +295,19 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
runtime_config.max_generated_tokens = max_generated_tokens; runtime_config.max_generated_tokens = max_generated_tokens;
runtime_config.prefill_tbatch_size = prefill_tbatch_size; runtime_config.prefill_tbatch_size = prefill_tbatch_size;
runtime_config.decode_qbatch_size = decode_qbatch_size; runtime_config.decode_qbatch_size = decode_qbatch_size;
if (prefill_tbatch_size > MMStorage::kMaxM) {
HWY_ABORT(
"prefill_tbatch_size %zu > kMaxM %zu: specify a smaller value, "
"or increase the constant in MMStorage.\n",
prefill_tbatch_size, MMStorage::kMaxM);
}
if (decode_qbatch_size > MMStorage::kMaxM) {
HWY_ABORT(
"decode_qbatch_size %zu > kMaxM %zu: specify a smaller value, "
"or increase the constant in MMStorage.\n",
decode_qbatch_size, MMStorage::kMaxM);
}
runtime_config.temperature = temperature; runtime_config.temperature = temperature;
runtime_config.top_k = top_k; runtime_config.top_k = top_k;
} }