mirror of https://github.com/google/gemma.cpp.git
1.31x batch prefill, 1.24x batch decode speedup: NUMA binding
Only the weights; binding MatMul output worsens batch=1 prefill.

- Update gemma_batch_bench to use --decode_qbatch.
- Fix/remove prefill_activations in gemma-inl.h.
- Refactor: use BasePageBytes directly when binding.
- Move BindB/C to .cc by de-templatizing.
- Remove MatOwners::AllocateFor because it is weights-specific (binding or not).
- Disband MatOwners, replace with std::vector<MatOwner>.

PiperOrigin-RevId: 759610477
parent c443adee33
commit e890d46f30
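Before the hunks, a minimal sketch of the new weight-allocation path, assuming the `MatPtr`/`MatOwner`/`MMParallel`/`BindB` interfaces shown in this diff (the function name `AllocateAndBindWeights` is hypothetical): each weight tensor gets its own `MatOwner`, allocation is parallelized, and only the weights are NUMA-bound.

```cpp
// Sketch only; mirrors the AllocateAndBindAll added to gemma/weights.cc below.
#include <vector>

#include "ops/matmul.h"              // BindB, MMParallel (per this diff)
#include "util/mat.h"                // MatPtr, MatOwner, MatPadding
#include "util/threading_context.h"  // ThreadingContext
#include "hwy/contrib/thread_pool/thread_pool.h"

namespace gcpp {

// One MatOwner per tensor. Allocation runs in parallel because faulting in
// large tensors is slow; BindB then pins each tensor's page-aligned rows to
// the NUMA node of the package that will read them during MatMul.
void AllocateAndBindWeights(const std::vector<MatPtr*>& mats,
                            std::vector<MatOwner>& owners,
                            hwy::ThreadPool& pool) {
  const size_t start = owners.size();
  owners.resize(start + mats.size());
  MMParallel parallel(ThreadingContext::Get());
  pool.Run(0, mats.size(), [&](uint64_t task, size_t /*thread*/) {
    owners[start + task].AllocateFor(*mats[task], MatPadding::kOdd);
    BindB(*mats[task], sizeof(float), parallel);  // only weights are bound
  });
}

}  // namespace gcpp
```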
@@ -172,7 +172,6 @@ cc_library(
        "//io:fields",
        "@highway//:hwy",
        "@highway//:profiler",
        "@highway//:thread_pool",
    ],
)
@@ -217,6 +216,7 @@ cc_library(
        ":configs",
        ":mat",
        ":model_store",
        ":ops",
        ":tensor_info",
        ":threading_context",
        "//compression:compress",
@@ -281,7 +281,6 @@ cc_library(
        ":allocator",
        ":basics",
        ":mat",
        ":threading",
        ":threading_context",
        "//compression:compress",
        "@highway//:algo",
@@ -20,6 +20,7 @@
#include <cmath>
#include <complex>
#include <vector>

#include "gtest/gtest.h"
#include "gemma/configs.h"

@@ -75,7 +76,7 @@ class WeightsWrapper {
  ModelWeightsPtrs<T>& get() { return weights_; }

 private:
  MatOwners owners_;
  std::vector<MatOwner> owners_;
  ModelWeightsPtrs<T> weights_;
};
@@ -67,7 +67,8 @@ class SbsWriterImpl : public ISbsWriter {
    }

    mat.AppendTo(serialized_mat_ptrs_);
    mat_owners_.AllocateFor(mat, MatPadding::kPacked);
    mat_owners_.push_back(MatOwner());
    mat_owners_.back().AllocateFor(mat, MatPadding::kPacked);

    // Handle gemma_export_test's MockArray. Write blobs so that the test
    // succeeds, but we only have 10 floats, not the full tensor.

@@ -121,7 +122,7 @@ class SbsWriterImpl : public ISbsWriter {
  }

  hwy::ThreadPool& pool_;
  MatOwners mat_owners_;
  std::vector<MatOwner> mat_owners_;
  CompressWorkingSet working_set_;
  BlobWriter writer_;
  std::vector<uint32_t> serialized_mat_ptrs_;
@@ -22,13 +22,13 @@
#include "gemma/configs.h"
#include "gemma/gemma.h"
#include "hwy/base.h"
#include "hwy/profiler.h"
#include "hwy/tests/hwy_gtest.h"

// This test can be run manually with the downloaded gemma weights.
// To run the test, pass the following flags:
// --model <model> --tokenizer <tokenizer_path> --weights <weights_path>
// --tokenizer <tokenizer_path> --weights <weights_path>
// It should pass for the following models:
// Gemma1: 2b-it (v1 and v1.1), 7b-it (v1 and v1.1), gr2b-it,
// Gemma2: gemma2-2b-it, 9b-it, 27b-it,

namespace gcpp {

@@ -76,26 +76,25 @@ class GemmaTest : public ::testing::Test {
    return replies;
  }

  void GenerateTokens(std::vector<std::string> &kQA, size_t num_questions) {
  void GenerateTokens(const std::vector<std::string>& questions) {
    ASSERT_NE(s_env->GetGemma(), nullptr);

    // Fills prompts round robin from `questions` until the desired batch size.
    std::vector<std::string> inputs;
    inputs.reserve(num_questions);
    for (size_t i = 0; i < num_questions; ++i) {
      inputs.push_back(kQA[i]);
    inputs.reserve(s_env->MutableConfig().decode_qbatch_size);
    size_t qpos = 0;
    for (size_t i = 0; i < inputs.capacity(); ++i) {
      inputs.push_back(questions[qpos++]);
      if (qpos == questions.size()) qpos = 0;
    }
    std::vector<std::string> responses = BatchGemmaReply(inputs);
    for (size_t i = 0; i < num_questions; ++i) {
      std::string response = responses.at(i);
      fprintf(stderr, "Batch answer %zu '%s'\n\n", i + 1, response.c_str());
    for (size_t i = 0; i < inputs.size(); ++i) {
      fprintf(stderr, "Batch answer %zu '%s'\n\n", i, responses[i].c_str());
    }
  }
};

TEST_F(GemmaTest, RandomQuestionsBatched) {
  s_env->MutableConfig().decode_qbatch_size = 3;
  s_env->MutableConfig().verbosity = 5;

  static std::vector<std::string> kQA = {
      {"Write me a poem about Australia?"},
      {"What's the history of Denmark?"},

@@ -130,8 +129,9 @@ TEST_F(GemmaTest, RandomQuestionsBatched) {
      {"Tell me about space travel."},
      {"Explain to me how electric cars work."},
  };
  static const size_t kNum = kQA.size();
  GenerateTokens(kQA, kNum);
  s_env->MutableConfig().verbosity = 5;
  GenerateTokens(kQA);
  PROFILER_PRINT_RESULTS();
}
}  // namespace
}  // namespace gcpp
@@ -83,6 +83,8 @@ struct Activations {

        env(env) {
    HWY_ASSERT(batch_size != 0);

    // Note that BindC on any MatMul output considerably slows down Prefill.
  }

  void SetBatchSize(size_t batch_size) {
@@ -21,7 +21,7 @@
#include <stdio.h>

#include <algorithm>  // std::min
#include <cstdio>
#include <memory>  // std::make_unique
#include <vector>

#include "gemma/activations.h"

@@ -1055,7 +1055,8 @@ HWY_NOINLINE void Prefill(
  // intensity, and so we are eventually compute-limited. We could devote some
  // threads to parallelizing over queries, but for simplicity we assign them
  // all to MatMul.
  const size_t max_tbatch_size = activations.x.Rows();
  const size_t max_tbatch_size = runtime_config.prefill_tbatch_size;
  HWY_DASSERT(max_tbatch_size <= activations.x.Rows());

  // For each query. `qi` is within the batch, not the global query index.
  for (size_t qi = 0; qi < num_queries; ++qi) {
@@ -1429,18 +1430,10 @@ void GenerateT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
  // Prefill stops before min_prompt_size - 1 because the last prompt
  // token is the first input token for generation.
  timing_info.prefill_start = hwy::platform::Now();
  // If tbatch is larger than the qbatch we already have in `activations`, then
  // allocate prefill_activations, otherwise reuse.
  const bool use_prefill_activations =
      runtime_config.prefill_tbatch_size > activations.x.Rows();
  Activations prefill_activations(
      weights.weights_config,
      use_prefill_activations ? runtime_config.prefill_tbatch_size : 0,
      activations.env);
  // Note that Prefill calls activations.SetBatchSize, so we reset it below.
  Prefill(queries_prompt, queries_mutable_pos, queries_prefix_end,
          query_idx_start, weights,
          use_prefill_activations ? prefill_activations : activations,
          runtime_config, div_seq_len, kv_caches);
          query_idx_start, weights, activations, runtime_config, div_seq_len,
          kv_caches);
  // Compute the number of tokens that were prefilled and notify timing_info.
  size_t prefilled_tokens = 0;
  for (size_t qi = 0; qi < num_queries; ++qi) {

@@ -1448,6 +1441,7 @@ void GenerateT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
  }
  timing_info.NotifyPrefill(prefilled_tokens);
  // queries_pos are incremented by Prefill.
  activations.SetBatchSize(num_queries);

  // Storage for the last generated token from each query, passed to the next
  // Transformer() call.
@@ -1489,8 +1483,10 @@ void GenerateSingleT(const ModelStore& model,
  constexpr size_t kNumQueries = 1;
  const size_t qbatch_start = 0;

  const size_t max_batch_size =
      HWY_MAX(kNumQueries, runtime_config.prefill_tbatch_size);
  // TODO: move into Gemma?
  Activations activations(model.Config(), kNumQueries, env);
  Activations activations(model.Config(), max_batch_size, env);

  const QueriesPromptTokens queries_prompt(&prompt, kNumQueries);
  QueriesPos queries_pos(&pos, kNumQueries);

@@ -1523,7 +1519,9 @@ void GenerateBatchT(const ModelStore& model,
    }
  }

  Activations activations(model.Config(), max_qbatch_size, env);
  const size_t max_batch_size =
      HWY_MAX(max_qbatch_size, runtime_config.prefill_tbatch_size);
  Activations activations(model.Config(), max_batch_size, env);

  for (size_t qbatch_start = 0; qbatch_start < num_queries;
       qbatch_start += max_qbatch_size) {

@@ -1557,6 +1555,7 @@ void GenerateImageTokensT(const ModelStore& model,
  prefill_runtime_config.prefill_tbatch_size =
      vit_config.seq_len / (vit_config.pool_dim * vit_config.pool_dim);
  Activations prefill_activations(vit_config, vit_config.seq_len, env);
  prefill_activations.SetBatchSize(prefill_runtime_config.prefill_tbatch_size);
  // Weights are for the full PaliGemma model, not just the ViT part.
  PrefillVit(weights, prefill_runtime_config, image, image_tokens,
             prefill_activations);
@@ -16,10 +16,10 @@
#include "gemma/weights.h"

#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

#include <cstdint>
#include <cstdlib>
#include <memory>
#include <random>
#include <string>

@@ -30,6 +30,7 @@
#include "gemma/configs.h"
#include "gemma/model_store.h"
#include "io/blob_store.h"
#include "ops/matmul.h"  // MMParallel
#include "util/mat.h"
#include "util/threading_context.h"
#include "hwy/base.h"

@@ -46,7 +47,7 @@ namespace gcpp {
static void InitAttWeightsNUQ(const LayerConfig& layer_config,
                              MatPtrT<NuqStream>& attn_vec_einsum_w,
                              MatPtrT<NuqStream>& att_weights,
                              MatOwners& mat_owners) {
                              std::vector<MatOwner>& mat_owners) {
  if (!attn_vec_einsum_w.HasPtr()) return;
  HWY_ASSERT(attn_vec_einsum_w.GetType() == Type::kNUQ);
@@ -91,11 +92,29 @@ static void SplitW1NUQ(const LayerConfig& layer_config) {
}

template <>
void LayerWeightsPtrs<NuqStream>::Fixup(MatOwners& mat_owners) {
void LayerWeightsPtrs<NuqStream>::Fixup(std::vector<MatOwner>& mat_owners) {
  InitAttWeightsNUQ(layer_config, attn_vec_einsum_w, att_weights, mat_owners);
  SplitW1NUQ(layer_config);
}

// Allocates multiple in parallel and binds to NUMA nodes.
static void AllocateAndBindAll(const std::vector<MatPtr*>& mats,
                               MatPadding padding,
                               std::vector<MatOwner>& owners,
                               hwy::ThreadPool& pool) {
  const size_t start = owners.size();
  owners.resize(start + mats.size());

  MMParallel parallel(ThreadingContext::Get());

  // Allocate in parallel because faulting in large tensors is slow.
  pool.Run(0, mats.size(), [&](uint64_t task, size_t /*thread*/) {
    owners[start + task].AllocateFor(*mats[task], padding);
    // TODO(janwas): MatMul outputs will later also be BF16.
    BindB(*mats[task], sizeof(float), parallel);
  });
}

// Parallel I/O into allocated memory, or mapped view of file. The latter is
// better when the file is huge, but page faults add noise to measurements.
enum class Mode { kRead, kMap };
@@ -209,10 +228,10 @@ static void ReadBatches(const BlobReader& reader,
}

// Aborts on error.
static void MapOrRead(const std::vector<MatPtr*>& mats, BlobReader& reader,
                      const std::vector<BlobRange>& ranges, Tristate map,
                      MatOwners& mat_owners, const MatPadding padding,
                      hwy::ThreadPool& pool) {
static void MapOrReadAll(const std::vector<MatPtr*>& mats, BlobReader& reader,
                         const std::vector<BlobRange>& ranges, Tristate map,
                         std::vector<MatOwner>& mat_owners,
                         const MatPadding padding, hwy::ThreadPool& pool) {
  HWY_ASSERT(mats.size() == ranges.size());

  if (ChooseMode(reader.file_bytes(), map) == Mode::kMap) {

@@ -226,7 +245,7 @@ static void MapOrRead(const std::vector<MatPtr*>& mats, BlobReader& reader,
  {
    PROFILER_ZONE("Startup.Weights.Allocate");
    // NOTE: this changes the stride of `mats`!
    mat_owners.AllocateFor(mats, padding, pool);
    AllocateAndBindAll(mats, padding, mat_owners, pool);
  }

  const std::vector<IOBatch> batches =

@@ -259,7 +278,7 @@ void WeightsOwner::ReadFromBlobs(const ModelStore& model, BlobReader& reader,
    });
  });

  MapOrRead(mats, reader, ranges, map, mat_owners_, padding, pool);
  MapOrReadAll(mats, reader, ranges, map, mat_owners_, padding, pool);

  Fixup(pool);
}
@@ -27,11 +27,12 @@
#include <vector>

#include "compression/types.h"  // IsF32
#include "gemma/configs.h"      // ModelConfig
#include "gemma/model_store.h"  // ModelStore
#include "gemma/tensor_info.h"  // TensorInfoRegistry
#include "io/blob_store.h"      // BlobWriter
#include "util/mat.h"           // MatPtr
#include "gemma/configs.h"      // ModelConfig
#include "gemma/model_store.h"  // ModelStore
#include "gemma/tensor_info.h"  // TensorInfoRegistry
#include "io/blob_store.h"      // BlobWriter
#include "ops/matmul.h"         // MatMulEnv
#include "util/mat.h"           // MatPtr
#include "hwy/contrib/thread_pool/thread_pool.h"

namespace gcpp {

@@ -191,7 +192,7 @@ struct LayerWeightsPtrs {
  MatPtrT<Weight> gating_einsum_w1;
  MatPtrT<Weight> gating_einsum_w2;
  MatPtrT<Weight> linear_w;
  // We don't yet have an RMSNorm that accepts all Weight.
  // > W8 is likely helpful.
  MatPtrT<WeightF32OrBF16> pre_attention_norm_scale;
  MatPtrT<WeightF32OrBF16> pre_ffw_norm_scale;
  MatPtrT<WeightF32OrBF16> post_attention_norm_scale;
@@ -299,17 +300,18 @@ struct LayerWeightsPtrs {

  // Allocates memory for all the tensors in the layer. Note that this is slow
  // (non-parallel) and only used for a stand-alone layer.
  void AllocateForTest(MatOwners& mat_owners) {
  void AllocateForTest(std::vector<MatOwner>& mat_owners) {
    ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
      // `backprop/` does not use row accessors and hence requires kPacked.
      mat_owners.AllocateFor(t.mat, MatPadding::kPacked);
      mat_owners.push_back(MatOwner());
      mat_owners.back().AllocateFor(t.mat, MatPadding::kPacked);
    });
  }

  // Must be called after reading weights via `ForEachTensor`.
  // TODO: exporters should bake this into the weights already.
  // WARNING: called from multiple threads; `mat_owners` requires a lock.
  void Fixup(MatOwners& mat_owners) {
  void Fixup(std::vector<MatOwner>& mat_owners) {
    InitAttWeights(mat_owners);
    SplitW1();
    SplitAttW1();

@@ -317,7 +319,7 @@ struct LayerWeightsPtrs {

 private:
  // Copies att_weights from `attn_vec_einsum_w`.
  void InitAttWeights(MatOwners& mat_owners) {
  void InitAttWeights(std::vector<MatOwner>& mat_owners) {
    // We only use this tensor for Gemma layers.
    if (layer_config.type != LayerAttentionType::kGemma) return;

@@ -343,7 +345,8 @@ struct LayerWeightsPtrs {
    {
      static std::mutex m;
      std::lock_guard<std::mutex> lock(m);
      mat_owners.AllocateFor(att_weights, MatPadding::kOdd);
      mat_owners.push_back(MatOwner());
      mat_owners.back().AllocateFor(att_weights, MatPadding::kOdd);
    }

    const size_t T_bytes = att_weights.ElementBytes();
@@ -575,7 +578,8 @@ struct ModelWeightsPtrs {

  // Instead of reading, only allocates memory for all tensors. Used by
  // `optimizer.cc` via the `Gemma` constructor without weights.
  void AllocateForTest(MatOwners& mat_owners, hwy::ThreadPool& pool) {
  void AllocateForTest(std::vector<MatOwner>& mat_owners,
                       hwy::ThreadPool& pool) {
    // First get a list of all the tensors.
    std::vector<MatPtr*> all_mat;
    all_mat.reserve(10 * c_layers.size());

@@ -583,14 +587,20 @@ struct ModelWeightsPtrs {
      all_mat.push_back(&t.mat);
    });

    // `backprop/` does not use row accessors and hence requires kPacked.
    mat_owners.AllocateFor(all_mat, MatPadding::kPacked, pool);
    const size_t start = mat_owners.size();
    mat_owners.resize(start + all_mat.size());

    // Allocate in parallel because faulting in large tensors is slow.
    pool.Run(0, all_mat.size(), [&](uint64_t task, size_t /*thread*/) {
      // `backprop/` does not use row accessors and hence requires kPacked.
      mat_owners[start + task].AllocateFor(*all_mat[task], MatPadding::kPacked);
    });
  }

  // For reshaping file tensors to the shape expected by the code. This would
  // ideally already happen in the importer. Must be called after reading and
  // updating the attention weights.
  void Fixup(MatOwners& mat_owners, hwy::ThreadPool& pool) {
  void Fixup(std::vector<MatOwner>& mat_owners, hwy::ThreadPool& pool) {
    pool.Run(0, c_layers.size(), [&](uint64_t layer, size_t /*thread*/) {
      GetLayer(layer)->Fixup(mat_owners);
    });

@@ -666,7 +676,7 @@ class WeightsOwner {
  std::unique_ptr<ModelWeightsPtrs<NuqStream>> nuq_weights_;

  // Owns the memory referenced by all `MatPtr`.
  MatOwners mat_owners_;
  std::vector<MatOwner> mat_owners_;
};

}  // namespace gcpp
@@ -79,7 +79,6 @@ void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents,
// M = A rows, K = A cols, N = C cols.
template <typename TA, typename TB = TA, typename TC = float>
void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
  const Allocator& allocator = env.ctx.allocator;
  hwy::ThreadPool& pool = env.ctx.pools.Pool(0);
  if (env.print_config || env.print_measurement) {
    fprintf(stderr, "\n");

@@ -92,8 +91,8 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
  const Extents2D B_extents(N, K);  // already transposed
  const Extents2D C_extents(M, N);

  MatStorageT<TC> c_slow_batch("c_slow_batch", C_extents, MatPadding::kOdd);
  MatStorageT<TC> c_batch("c_batch", C_extents, MatPadding::kOdd);
  MatStorageT<TC> c_slow_mat("c_slow_batch", C_extents, MatPadding::kOdd);
  MatStorageT<TC> c_mat("c_batch", C_extents, MatPadding::kOdd);

  MatStorageT<float> add_storage("add", Extents2D(), MatPadding::kPacked);
  if (add) {

@@ -105,7 +104,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
  MatStorageT<TB> b_trans = GenerateTransposedMat<TB>(B_extents, pool);

  const float* add_row = add ? add_storage.PackedScale1() : nullptr;
  const RowPtr<TC> C = RowPtrFromMat(c_batch);
  const RowPtr<TC> C = RowPtrFromMat(c_mat);

  // Fewer reps for large batch sizes, which take longer.
  const size_t num_samples = M < 32 ? 20 : 12;

@@ -115,8 +114,8 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
  // Ensure usage conditions are set before autotuning. Both binding and
  // spinning may materially affect the choice of config. No harm in calling
  // BindB/C if there is a single package: they will be a no-op.
  BindB(allocator, sizeof(TC), b_trans, env.parallel);
  BindC(allocator, A_extents.rows, C, env.parallel);
  BindB(b_trans, sizeof(TC), env.parallel);
  BindC(c_mat, env.parallel);

  Tristate use_spinning = Tristate::kDefault;
  env.ctx.pools.MaybeStartSpinning(use_spinning);
@@ -25,10 +25,12 @@

#include "util/allocator.h"
#include "util/basics.h"
#include "util/threading.h"
#include "util/mat.h"
#include "util/threading_context.h"
#include "hwy/base.h"
#include "hwy/detect_targets.h"
#include "hwy/per_target.h"
#include "hwy/profiler.h"
#include "hwy/timer.h"

namespace gcpp {

@@ -386,7 +388,7 @@ std::vector<MMConfig> MMCandidates(const Allocator& allocator, size_t M,
// rows for that.
static size_t NPMultiple(const Allocator& allocator, size_t N,
                         size_t sizeof_TC, size_t nr, size_t num_packages) {
  size_t np_multiple = allocator.QuantumBytes() / sizeof_TC;
  size_t np_multiple = allocator.BasePageBytes() / sizeof_TC;
  // If binding, `np_multiple` is typically 1024 and `num_packages` > 1. For
  // `N` < 4096, this can cause significant load imbalance. If split unevenly,
  // choose a smaller multiple.
@@ -423,4 +425,59 @@ MatMulEnv::MatMulEnv(ThreadingContext& ctx)
  have_timer_stop = hwy::platform::HaveTimerStop(cpu100);
}

void BindB(const MatPtr& B, size_t sizeof_TC, MMParallel& parallel) {
  Allocator& allocator = parallel.allocator();
  if (!allocator.ShouldBind()) return;
  if (B.Rows() == 1) return;

  PROFILER_ZONE("Startup.BindB");

  const IndexRangePartition ranges_np =
      parallel.RangesOfNP(MMParallel::kMaxPackages, B.Rows(), sizeof_TC, kNR);
  for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
    const IndexRange& rows_b = ranges_np.Range(pkg_idx);
    const size_t node = parallel.Node(pkg_idx);
    uintptr_t begin =
        reinterpret_cast<uintptr_t>(B.RowT<uint8_t>(rows_b.begin()));
    uintptr_t end = begin + rows_b.Num() * B.Stride() * B.ElementBytes();
    // B row padding is less than the page size, so only bind the subset that
    // is page-aligned.
    begin = hwy::RoundUpTo(begin, allocator.BasePageBytes());
    end = hwy::RoundDownTo(end, allocator.BasePageBytes());
    if (HWY_LIKELY(begin != end)) {
      allocator.BindMemory(reinterpret_cast<void*>(begin), end - begin, node);
    }
  }
}

// C is BF16/float, or double for partial
void BindC(const MatPtr& C, MMParallel& parallel) {
  Allocator& allocator = parallel.allocator();
  if (!allocator.ShouldBind()) return;

  PROFILER_ZONE("Startup.BindC");

  const IndexRangePartition ranges_np = parallel.RangesOfNP(
      MMParallel::kMaxPackages, C.Cols(), C.ElementBytes(), kNR);
  bool ok = true;
  for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
    const IndexRange& cols_c = ranges_np.Range(pkg_idx);
    // `BindMemory` requires page alignment. These are in bytes.
    const size_t begin = hwy::RoundUpTo(cols_c.begin() * C.ElementBytes(),
                                        allocator.BasePageBytes());
    const size_t end = hwy::RoundDownTo(cols_c.end() * C.ElementBytes(),
                                        allocator.BasePageBytes());

    const size_t node = parallel.Node(pkg_idx);
    for (size_t im = 0; im < C.Rows(); ++im) {
      ok &= allocator.BindMemory(C.MutableRowT<uint8_t>(im) + begin,
                                 end - begin, node);
    }
  }
  if (HWY_UNLIKELY(!ok)) {
    HWY_WARN("Failed to bind C (%zux%zu), %zu packages.", C.Rows(), C.Cols(),
             ranges_np.NumTasks());
  }
}

}  // namespace gcpp
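Both `BindB` and the new `BindC` above bind only the page-aligned subset of each per-package range, since `BindMemory` works on whole base pages. A self-contained sketch of that rounding, where the 4 KiB page size and the byte offsets are assumptions for illustration (the real code queries `Allocator::BasePageBytes()`):

```cpp
// Standalone illustration of the page rounding used by BindB/BindC above.
// RoundUp/RoundDown mirror hwy::RoundUpTo/hwy::RoundDownTo.
#include <cstddef>
#include <cstdio>

static size_t RoundUp(size_t x, size_t mult) { return (x + mult - 1) / mult * mult; }
static size_t RoundDown(size_t x, size_t mult) { return x / mult * mult; }

int main() {
  const size_t kPageBytes = 4096;    // assumed base page size
  const size_t range_begin = 10000;  // start of one package's byte range
  const size_t range_end = 1000000;  // end of that range
  // Only whole pages inside the range are bound; the partial pages at either
  // end stay on whichever node first touches them.
  const size_t bind_begin = RoundUp(range_begin, kPageBytes);  // 12288
  const size_t bind_end = RoundDown(range_end, kPageBytes);    // 999424
  if (bind_begin < bind_end) {
    std::printf("bind [%zu, %zu): %zu bytes\n", bind_begin, bind_end,
                bind_end - bind_begin);
  }
  return 0;
}
```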
ops/matmul.h (59 lines changed)
@@ -46,6 +46,8 @@ namespace gcpp {
// `MMAddHorizontalSumsIntoPartial`. We ensure `C.Cols() % kNR == 0`.
constexpr size_t kNR = 4;

// Mostly stateless, can be constructed on the fly by weights.cc, but captures
// the singleton ThreadingContext to reduce MatMul call overhead.
class MMParallel {
 public:
  static constexpr size_t kMaxPackages = 4;

@@ -55,6 +57,8 @@ class MMParallel {
    HWY_DASSERT(ctx_.pools.NumPackages() <= kMaxPackages);
  }

  Allocator& allocator() const { return ctx_.allocator; }

  // Initial static partitioning of B rows across packages.
  IndexRangePartition RangesOfNP(size_t max_packages, size_t N,
                                 size_t sizeof_TC, size_t nr) const;
@@ -168,31 +172,9 @@ class MMParallel {
  ThreadingContext& ctx_;
};

template <typename TC>  // BF16/float for C, double for partial
void BindC(const Allocator& allocator, size_t M, const RowPtr<TC>& C,
           MMParallel& parallel) {
  if (!allocator.ShouldBind()) return;

  const IndexRangePartition ranges_np =
      parallel.RangesOfNP(MMParallel::kMaxPackages, C.Cols(), sizeof(TC), kNR);
  const size_t quantum = allocator.Quantum<TC>();
  bool ok = true;
  for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
    const IndexRange& cols_c = ranges_np.Range(pkg_idx);
    const size_t node = parallel.Node(pkg_idx);
    for (size_t im = 0; im < M; ++im) {
      // `BindMemory` requires page alignment.
      const size_t begin = hwy::RoundUpTo(cols_c.begin(), quantum);
      const size_t end = hwy::RoundDownTo(cols_c.end(), quantum);
      ok &= allocator.BindMemory(C.Row(im) + begin, (end - begin) * sizeof(TC),
                                 node);
    }
  }
  if (HWY_UNLIKELY(!ok)) {
    HWY_WARN("Failed to bind C (%zux%zu), %zu packages.", M, C.Cols(),
             ranges_np.NumTasks());
  }
}
void BindB(const MatPtr& B, size_t sizeof_TC, MMParallel& parallel);
// C is BF16/float, or double for partial.
void BindC(const MatPtr& C, MMParallel& parallel);

// Per-package storage for packed A, and one global C-shaped `partial` for
// accumulating partial dot products (sections of K).
@@ -227,7 +209,7 @@ class MMStorage {
      const size_t node = parallel.Node(pkg_idx);
      size_t bytes = pkg_A_[pkg_idx]->Rows() * pkg_A_[pkg_idx]->Stride() *
                     pkg_A_[pkg_idx]->ElementBytes();
      bytes = hwy::RoundDownTo(bytes, allocator.QuantumBytes());
      bytes = hwy::RoundDownTo(bytes, allocator.BasePageBytes());
      if (!allocator.BindMemory(pkg_A_[pkg_idx]->Row(0), bytes, node)) {
        HWY_WARN("Failed to bind memory for package %zu", pkg_idx);
      }

@@ -235,7 +217,7 @@ class MMStorage {
    });

    // Avoid cross-package accesses.
    BindC(allocator, kMaxM, partial_, parallel);
    BindC(partial_storage_, parallel);
  }

  // Returns per-package matrix view.
@@ -681,29 +663,6 @@ struct MMZone {
};
#endif  // PROFILER_ENABLED

template <typename TB>
void BindB(const Allocator& allocator, size_t sizeof_TC, const MatPtrT<TB>& B,
           MMParallel& parallel) {
  if (!allocator.ShouldBind()) return;

  const IndexRangePartition ranges_np =
      parallel.RangesOfNP(MMParallel::kMaxPackages, B.Rows(), sizeof_TC, kNR);
  const size_t quantum = allocator.Quantum<TB>();
  for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
    const IndexRange& rows_b = ranges_np.Range(pkg_idx);
    const size_t node = parallel.Node(pkg_idx);
    uintptr_t begin = reinterpret_cast<uintptr_t>(B.Row(rows_b.begin()));
    uintptr_t end = begin + rows_b.Num() * B.Stride() * sizeof(TB);
    // B is not yet guaranteed to have padded rows, so only bind the
    // subset that is page-aligned.
    begin = hwy::RoundUpTo(begin, quantum);
    end = hwy::RoundDownTo(end, quantum);
    if (HWY_LIKELY(begin != end)) {
      allocator.BindMemory(reinterpret_cast<void*>(begin), end - begin, node);
    }
  }
}

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_H_
@@ -98,16 +98,12 @@ class Allocator {
  // = HWY_MAX(LineBytes(), VectorBytes())
  size_t StepBytes() const { return step_bytes_; }

  // File size multiple required for memory mapping.
  // File size multiple required for memory mapping. Also used when binding
  // memory to NUMA nodes (see `BindB/BindC`).
  size_t BasePageBytes() const { return base_page_bytes_; }

  // Either StepBytes or BasePageBytes if NUMA.
  // Desired allocator alignment: Either StepBytes, or BasePageBytes if NUMA.
  size_t QuantumBytes() const { return quantum_bytes_; }

  // For rounding down elements to the page size in `BindB/BindC`.
  template <typename T>
  size_t Quantum() const {
    return QuantumBytes() / sizeof(T);
  }

  // L1 and L2 are typically per core.
  size_t L1Bytes() const { return l1_bytes_; }
util/mat.cc (13 lines changed)
@@ -19,11 +19,9 @@
#include <stdint.h>

#include <random>
#include <vector>

#include "util/threading_context.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/per_target.h"  // VectorBytes
#include "hwy/profiler.h"

@@ -126,15 +124,4 @@ void MatOwner::AllocateFor(MatPtr& mat, MatPadding padding) {
  mat.SetPtr(storage_.get(), stride);
}

void MatOwners::AllocateFor(const std::vector<MatPtr*>& mats,
                            MatPadding padding, hwy::ThreadPool& pool) {
  const size_t start = owners_.size();
  owners_.resize(start + mats.size());

  // Allocate in parallel because faulting in large tensors is slow.
  pool.Run(0, mats.size(), [&](uint64_t task, size_t /*thread*/) {
    owners_[start + task].AllocateFor(*mats[task], padding);
  });
}

}  // namespace gcpp
util/mat.h (23 lines changed)
@@ -22,7 +22,6 @@

#include <random>
#include <string>
#include <vector>

// IWYU pragma: begin_exports
#include "compression/types.h"  // Type

@@ -32,7 +31,6 @@
#include "util/basics.h"  // Extents2D
// IWYU pragma: end_exports
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

namespace gcpp {

@@ -376,31 +374,14 @@ class MatOwner {
  MatOwner& operator=(MatOwner&&) = default;

  // Allocates the type/extents indicated by `mat` and sets its pointer.
  // Ignores `padding` for NUQ tensors, which are always packed.
  // Thread-compatible, weights are allocated in parallel.
  void AllocateFor(MatPtr& mat, MatPadding padding);

 private:
  AlignedPtr<uint8_t[]> storage_;
};

// Multiple `MatOwner`, with support for parallel allocation.
class MatOwners {
 public:
  // Ignores `padding` for NUQ tensors, which are always packed.
  void AllocateFor(MatPtr& mat, MatPadding padding) {
    if (mat.GetType() == Type::kNUQ) padding = MatPadding::kPacked;
    owners_.push_back(MatOwner());
    owners_.back().AllocateFor(mat, MatPadding::kPacked);
  }

  // Allocates multiple in parallel. Ignores `padding` for NUQ tensors,
  // which are always packed.
  void AllocateFor(const std::vector<MatPtr*>& mats, MatPadding padding,
                   hwy::ThreadPool& pool);

 private:
  std::vector<MatOwner> owners_;
};

// `MatStorageT` IS-A `MatPtrT` and HAS-A `MatOwner`. Used by `backprop/` and
// tests to allocate and access tensors of a known type. By contrast, the
// heterogeneous model weights are owned by vectors of `MatOwner`.
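With `MatOwners` removed, call sites that previously used its single-tensor `AllocateFor` now append a `MatOwner` directly; per the new comment above, `MatOwner::AllocateFor` itself ignores `padding` for NUQ tensors, which are always packed. A minimal migration sketch, assuming only the `MatOwner`/`MatPtr`/`MatPadding` interface kept in `util/mat.h` (the helper name `AllocateInto` is hypothetical):

```cpp
// Sketch of the call-site migration away from MatOwners.
#include <vector>

#include "util/mat.h"  // MatOwner, MatPtr, MatPadding

namespace gcpp {

// Before: MatOwners mat_owners; mat_owners.AllocateFor(mat, padding);
// After: append a MatOwner to a plain vector and allocate through it.
void AllocateInto(std::vector<MatOwner>& owners, MatPtr& mat,
                  MatPadding padding) {
  owners.push_back(MatOwner());
  owners.back().AllocateFor(mat, padding);
}

}  // namespace gcpp
```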