De-singleton ThreadingContext so callers can pass in their own

weights.cc: fix BindB argument for bf16 tensors
threading_test: enable autotune
PiperOrigin-RevId: 785763618
Jan Wassenberg 2025-07-22 02:07:58 -07:00 committed by Copybara-Service
parent 5474146129
commit e76e29ce11
37 changed files with 332 additions and 337 deletions
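
For reference, a minimal sketch of the new caller-side setup (it mirrors the run.cc and examples hunks below; `threading`, `inference`, and `loader` stand for the usual parsed argument structs):

    // Callers now own the ThreadingContext instead of fetching a singleton.
    gcpp::ThreadingContext ctx(gcpp::UpdateArgs(threading, inference));
    gcpp::MatMulEnv env(ctx);
    gcpp::Gemma gemma(loader, inference, ctx);
    gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference, ctx.allocator);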

View File

@@ -389,6 +389,7 @@ cc_test(
     deps = [
         ":mat",
         ":ops",
+        ":threading_context",
         "@googletest//:gtest_main",  # buildcleaner: keep
         "//compression:compress",
         "@highway//:hwy",

View File

@@ -57,6 +57,9 @@ class SbsWriterImpl : public ISbsWriter {
   template <typename Packed>
   void InsertT(const char* name, F32Span weights,
                const TensorInfo& tensor_info) {
+    // TODO(janwas): 1D parallel-for.
+    hwy::ThreadPool& pool = ctx_.pools.Pool();
     MatPtrT<Packed> mat(name, ExtentsFromInfo(&tensor_info));
     // SFP and NUQ (which uses SFP for cluster centers) have a limited range
     // and depending on the input values may require rescaling. Scaling is
@@ -73,13 +76,13 @@ class SbsWriterImpl : public ISbsWriter {
     mat.AppendTo(serialized_mat_ptrs_);
     mat_owners_.push_back(MatOwner());
-    mat_owners_.back().AllocateFor(mat, MatPadding::kPacked);
+    mat_owners_.back().AllocateFor(mat, ctx_.allocator, MatPadding::kPacked);
     // Handle gemma_export_test's MockArray. Write blobs so that the test
     // succeeds, but we only have 10 floats, not the full tensor.
     if (weights.size() == 10 && mat.Extents().Area() != 10) {
       Compress(weights.data(), weights.size(), working_set_, mat.Span(),
-               /*packed_ofs=*/0, pool_);
+               /*packed_ofs=*/0, pool);
       writer_.Add(name, mat.Packed(), mat.ElementBytes() * 10);
       return;
     }
@@ -89,12 +92,12 @@ class SbsWriterImpl : public ISbsWriter {
              TypeName(TypeEnum<Packed>()));
     HWY_ASSERT(weights.size() == mat.Extents().Area());
     Compress(weights.data(), weights.size(), working_set_, mat.Span(),
-             /*packed_ofs=*/0, pool_);
+             /*packed_ofs=*/0, pool);
     writer_.Add(name, mat.Packed(), mat.PackedBytes());
   }
  public:
-  SbsWriterImpl() : pool_(ThreadingContext::Get().pools.Pool()) {}
+  SbsWriterImpl() : ctx_(ThreadingArgs()) {}
   void Insert(const char* name, F32Span weights, Type type,
               const TensorInfo& tensor_info) override {
@@ -122,18 +125,18 @@ class SbsWriterImpl : public ISbsWriter {
     const GemmaTokenizer tokenizer(
         tokenizer_path.empty() ? kMockTokenizer
                                : ReadFileToString(Path(tokenizer_path)));
-    WriteSingleFile(config, tokenizer, serialized_mat_ptrs_, writer_, pool_,
-                    gcpp::Path(path));
+    WriteSingleFile(config, tokenizer, serialized_mat_ptrs_, writer_,
+                    ctx_.pools.Pool(), gcpp::Path(path));
   }
-  hwy::ThreadPool& pool_;
+  ThreadingContext ctx_;
   std::vector<MatOwner> mat_owners_;
   CompressWorkingSet working_set_;
   BlobWriter writer_;
   std::vector<uint32_t> serialized_mat_ptrs_;
 };
-ISbsWriter* NewSbsWriter() { return new SbsWriterImpl; }
+ISbsWriter* NewSbsWriter() { return new SbsWriterImpl(); }
 }  // namespace HWY_NAMESPACE
 }  // namespace gcpp

View File

@@ -69,12 +69,13 @@ void ForeachPackedAndRawType() {
 // Generates inputs: deterministic, within max SfpStream range.
 template <typename MatT>
-MatStorageT<MatT> GenerateMat(const Extents2D& extents, MatPadding padding,
+MatStorageT<MatT> GenerateMat(const Extents2D& extents,
+                              const Allocator& allocator, MatPadding padding,
                               hwy::ThreadPool& pool) {
   gcpp::CompressWorkingSet ws;
   ws.tls.resize(pool.NumWorkers());
-  MatStorageT<float> raw("raw", extents, MatPadding::kPacked);
-  MatStorageT<MatT> compressed("mat", extents, padding);
+  MatStorageT<float> raw("raw", extents, allocator, MatPadding::kPacked);
+  MatStorageT<MatT> compressed("mat", extents, allocator, padding);
   const float scale = SfpStream::kMax / extents.Area();
   pool.Run(0, extents.rows, [&](const size_t r, size_t thread) {
     float* HWY_RESTRICT row = raw.Row(r);
@@ -95,12 +96,13 @@ MatStorageT<MatT> GenerateMat(const Extents2D& extents, MatPadding padding,
 // Same, but `extents` describes the transposed matrix.
 template <typename MatT>
 MatStorageT<MatT> GenerateTransposedMat(const Extents2D extents,
+                                        const Allocator& allocator,
                                         MatPadding padding,
                                         hwy::ThreadPool& pool) {
   gcpp::CompressWorkingSet ws;
   ws.tls.resize(pool.NumWorkers());
-  MatStorageT<float> raw("raw", extents, MatPadding::kPacked);
-  MatStorageT<MatT> compressed("trans", extents, padding);
+  MatStorageT<float> raw("raw", extents, allocator, MatPadding::kPacked);
+  MatStorageT<MatT> compressed("trans", extents, allocator, padding);
   const float scale = SfpStream::kMax / extents.Area();
   pool.Run(0, extents.rows, [&](const size_t r, size_t thread) {
     float* HWY_RESTRICT row = raw.Row(r);

View File

@@ -74,7 +74,8 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
     size_t num_tokens = std::min<size_t>(prompt.size() - pos, batch_tokens);
     std::vector<int> prompt_slice(prompt.begin() + pos,
                                   prompt.begin() + pos + num_tokens);
-    KVCache kv_cache(gemma.GetModelConfig(), gemma.Inference());
+    KVCache kv_cache(gemma.GetModelConfig(), gemma.Inference(),
+                     env.MutableEnv().ctx.allocator);
     float entropy =
         ComputeCrossEntropy(*env.GetGemma(), num_tokens, prompt_slice, kv_cache,
                             env.MutableEnv(), env.Verbosity());

View File

@@ -50,14 +50,13 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
 GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
                    const InferenceArgs& inference)
-    : env_(MakeMatMulEnv(threading, inference)),
-      gemma_(loader, inference, env_.ctx.pools) {
+    : ctx_(threading), env_(ctx_), gemma_(loader, inference, ctx_) {
   const ModelConfig& config = gemma_.GetModelConfig();
   // Only allocate one for starters because GenerateBatch might not be called.
-  kv_caches_.push_back(KVCache(config, inference));
+  kv_caches_.push_back(KVCache(config, inference, ctx_.allocator));
   if (inference.verbosity >= 2) {
-    ShowConfig(loader, threading, inference, config);
+    ShowConfig(loader, threading, inference, config, ctx_);
   }
   InitGenerator(inference, gen_);
@@ -141,7 +140,8 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
   // Ensure we have at least one KVCache per query.
   while (kv_caches_.size() < num_queries) {
-    kv_caches_.push_back(KVCache(gemma_.GetModelConfig(), gemma_.Inference()));
+    kv_caches_.push_back(
+        KVCache(gemma_.GetModelConfig(), gemma_.Inference(), ctx_.allocator));
   }
   const hwy::Span<KVCache> kv_caches(&kv_caches_[0], num_queries);
@@ -228,7 +228,8 @@ static constexpr const char* CompiledConfig() {
 }
 void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
-                const InferenceArgs& inference, const ModelConfig& config) {
+                const InferenceArgs& inference, const ModelConfig& config,
+                const ThreadingContext& ctx) {
   threading.Print(inference.verbosity);
   loader.Print(inference.verbosity);
   inference.Print(inference.verbosity);
@@ -241,7 +242,6 @@ void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
   char* dt = ctime(&now);  // NOLINT
   char cpu100[100] = "unknown";
   (void)hwy::platform::GetCpuString(cpu100);
-  const ThreadingContext& ctx = ThreadingContext::Get();
   fprintf(stderr,
           "Date & Time : %s"  // dt includes \n

View File

@@ -49,9 +49,6 @@ class GemmaEnv {
   GemmaEnv(int argc, char** argv);
   GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
            const InferenceArgs& inference);
-  // Avoid memory leaks in test.
-  ~GemmaEnv() { ThreadingContext::ThreadHostileInvalidate(); }
   MatMulEnv& Env() { return env_; }
   size_t MaxGeneratedTokens() const {
@@ -115,6 +112,7 @@ class GemmaEnv {
   MatMulEnv& MutableEnv() { return env_; }
  private:
+  ThreadingContext ctx_;
   MatMulEnv env_;
   Gemma gemma_;
   std::mt19937 gen_;  // Random number generator.
@@ -126,7 +124,8 @@ class GemmaEnv {
 void LogSpeedStats(double time_start, size_t total_tokens);
 void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
-                const InferenceArgs& inference, const ModelConfig& config);
+                const InferenceArgs& inference, const ModelConfig& config,
+                const ThreadingContext& ctx);
 void ShowHelp(const LoaderArgs& loader, const ThreadingArgs& threading,
               const InferenceArgs& inference);
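
Note on the member order above: members are initialized in declaration order, so `ctx_` must be declared before `env_` and `gemma_`, which hold references into it. A minimal sketch of the same pattern (the `Holder` class is hypothetical, not part of this change):

    class Holder {
     public:
      explicit Holder(const gcpp::ThreadingArgs& args)
          : ctx_(args), env_(ctx_) {}  // ctx_ is constructed first.
     private:
      gcpp::ThreadingContext ctx_;  // must precede env_, which references it
      gcpp::MatMulEnv env_;
    };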

View File

@@ -51,9 +51,10 @@ int main(int argc, char** argv) {
   }
   // Instantiate model and KV Cache
-  gcpp::MatMulEnv env(MakeMatMulEnv(threading, inference));
-  gcpp::Gemma gemma(loader, inference, env.ctx.pools);
-  gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference);
+  gcpp::ThreadingContext ctx(gcpp::UpdateArgs(threading, inference));
+  gcpp::MatMulEnv env(ctx);
+  gcpp::Gemma gemma(loader, inference, ctx);
+  gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference, ctx.allocator);
   size_t generated = 0;
   // Initialize random number generator

View File

@@ -35,9 +35,10 @@ class SimplifiedGemma {
   SimplifiedGemma(const gcpp::LoaderArgs& loader,
                   const gcpp::ThreadingArgs& threading = gcpp::ThreadingArgs(),
                   const gcpp::InferenceArgs& inference = gcpp::InferenceArgs())
-      : env_(MakeMatMulEnv(threading, inference)),
-        gemma_(loader, inference, env_.ctx.pools),
-        kv_cache_(gemma_.GetModelConfig(), inference) {
+      : ctx_(UpdateArgs(threading, inference)),
+        env_(ctx_),
+        gemma_(loader, inference, ctx_),
+        kv_cache_(gemma_.GetModelConfig(), inference, ctx_.allocator) {
     // Initialize random number generator
     std::random_device rd;
     gen_.seed(rd());
@@ -88,6 +89,7 @@ class SimplifiedGemma {
   ~SimplifiedGemma() = default;
  private:
+  gcpp::ThreadingContext ctx_;
   gcpp::MatMulEnv env_;
   gcpp::Gemma gemma_;
   gcpp::KVCache kv_cache_;

View File

@@ -35,13 +35,15 @@ namespace gcpp {
 struct GriffinActivations {
   GriffinActivations(const ModelConfig& config, size_t batch_size,
-                     MatPadding pad)
-      : griffin_x("griffin_x", Extents2D(batch_size, config.model_dim), pad),
-        griffin_y("griffin_y", Extents2D(batch_size, config.model_dim), pad),
-        griffin_gate_x("griffin_gate_x",
-                       Extents2D(batch_size, config.model_dim), pad),
-        griffin_multiplier("griffin_mul",
-                           Extents2D(batch_size, config.model_dim), pad) {}
+                     const Allocator& allocator)
+      : griffin_x(
+            MatFactory("griffin_x", batch_size, config.model_dim, allocator)),
+        griffin_y(
+            MatFactory("griffin_y", batch_size, config.model_dim, allocator)),
+        griffin_gate_x(MatFactory("griffin_gate_x", batch_size,
+                                  config.model_dim, allocator)),
+        griffin_multiplier(MatFactory("griffin_mul", batch_size,
+                                      config.model_dim, allocator)) {}
   void SetBatchSize(size_t batch_size) {
     if (griffin_x.Rows() == 0) return;
@@ -70,34 +72,34 @@ struct AttentionActivations {
   AttentionActivations(
       const ModelConfig& config, const LayerConfig& layer_config,
-      size_t batch_size, size_t seq_len, MatPadding pad,
+      size_t batch_size, size_t seq_len, const Allocator& allocator,
       std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
       : config(config),
         // `vocab_size == 0` means it is for Vit part, VitAttention is still MHA
         // and does not use an external KV cache.
-        q("q",
-          Extents2D(batch_size,
-                    config.vocab_size == 0
-                        ? layer_config.heads * 3 * layer_config.qkv_dim
-                        : layer_config.heads * layer_config.qkv_dim),
-          pad),
-        pre_att_rms_out("pre_att_rms_out",
-                        Extents2D(batch_size, config.model_dim), pad),
-        att("att", Extents2D(batch_size, layer_config.heads * seq_len), pad),
-        att_out(
-            "att_out",
-            Extents2D(batch_size, layer_config.heads * layer_config.qkv_dim),
-            pad),
-        att_sums("att_sums", Extents2D(batch_size, config.model_dim), pad),
+        q(MatFactory("q", batch_size,
+                     config.vocab_size == 0
+                         ? layer_config.heads * 3 * layer_config.qkv_dim
+                         : layer_config.heads * layer_config.qkv_dim,
+                     allocator)),
+        pre_att_rms_out(MatFactory("pre_att_rms_out", batch_size,
+                                   config.model_dim, allocator)),
+        att(MatFactory("att", batch_size, layer_config.heads * seq_len,
+                       allocator)),
+        att_out(MatFactory("att_out", batch_size,
+                           layer_config.heads * layer_config.qkv_dim,
+                           allocator)),
+        att_sums(
+            MatFactory("att_sums", batch_size, config.model_dim, allocator)),
         inv_timescale(
-            CreateInvTimescale(layer_config.qkv_dim,
+            CreateInvTimescale(allocator, layer_config.qkv_dim,
                                layer_config.post_qk == PostQKType::HalfRope)),
         inv_timescale_global(CreateInvTimescale(
-            layer_config.qkv_dim, layer_config.post_qk == PostQKType::HalfRope,
-            1000000.0)),
+            allocator, layer_config.qkv_dim,
+            layer_config.post_qk == PostQKType::HalfRope, 1000000.0)),
         div_seq_len(static_cast<uint32_t>(seq_len)),
         div_heads(static_cast<uint32_t>(layer_config.heads)),
@@ -149,21 +151,23 @@ struct AttentionActivations {
 struct Activations {
   Activations(const ModelConfig& config, size_t batch_size, size_t seq_len,
+              const Allocator& allocator,
               std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
       : layer_config(config.layer_configs[0]),
-        x("x", Extents2D(batch_size, config.model_dim), pad_),
-        logits("logits", Extents2D(batch_size, config.vocab_size), pad_),
-        pre_ffw_rms_out("pre_ffw_rms_out",
-                        Extents2D(batch_size, config.model_dim), pad_),
-        C1("C1", Extents2D(batch_size, layer_config.ff_hidden_dim), pad_),
-        C2("C2", Extents2D(batch_size, layer_config.ff_hidden_dim), pad_),
-        ffw_out("ffw_out", Extents2D(batch_size, config.model_dim), pad_),
-        attention(config, layer_config, batch_size, seq_len, pad_, row_ptrs),
+        x(MatFactory("x", batch_size, config.model_dim, allocator)),
+        logits(MatFactory("logits", batch_size, config.vocab_size, allocator)),
+        pre_ffw_rms_out(MatFactory("pre_ffw_rms_out", batch_size,
+                                   config.model_dim, allocator)),
+        C1(MatFactory("C1", batch_size, layer_config.ff_hidden_dim, allocator)),
+        C2(MatFactory("C2", batch_size, layer_config.ff_hidden_dim, allocator)),
+        ffw_out(MatFactory("ffw_out", batch_size, config.model_dim, allocator)),
+        attention(config, layer_config, batch_size, seq_len, allocator,
+                  row_ptrs),
         griffin(config, config.model == Model::GRIFFIN_2B ? batch_size : 0,
-                pad_) {
+                allocator) {
     HWY_ASSERT(batch_size != 0);
     // For MatMul outputs, precompute their row pointers.
@@ -193,15 +197,14 @@ struct Activations {
   }
   const LayerConfig& layer_config;
-  const Extents2D none_ = Extents2D();
-  const MatPadding pad_ = MatPadding::kOdd;
   MatStorageT<float> x;  // input
   MatStorageT<float> logits;
   // Gated FFW
   MatStorageT<BF16> pre_ffw_rms_out;
-  MatStorageT<float> C1;  // TODO: BF16 after Activation() supports it
+  // Norm may be large, so prefer to keep as f32.
+  MatStorageT<float> C1;
   MatStorageT<float> C2;
   MatStorageT<BF16> ffw_out;
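
Sketch of the `MatFactory` pattern introduced above; the exact return type of `MatFactory` is not shown in this diff and is assumed here to be directly usable as a `MatStorageT<T>` initializer (the `TinyActivations` struct is hypothetical):

    struct TinyActivations {
      TinyActivations(size_t batch_size, size_t model_dim,
                      const gcpp::Allocator& allocator)
          // Name, rows, cols, and the caller-provided allocator; padding is
          // presumably handled inside MatFactory now that pad_ is gone.
          : x(gcpp::MatFactory("x", batch_size, model_dim, allocator)) {}
      gcpp::MatStorageT<float> x;
    };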

View File

@@ -43,8 +43,10 @@ namespace gcpp {
 // ConversationData constructor implementation
 ConversationData::ConversationData(const ModelConfig& model_config,
-                                   const InferenceArgs& inference_args)
-    : kv_cache(std::make_unique<KVCache>(model_config, inference_args)),
+                                   const InferenceArgs& inference_args,
+                                   const Allocator& allocator)
+    : kv_cache(
+          std::make_unique<KVCache>(model_config, inference_args, allocator)),
       abs_pos(0) {}
 // ConversationData copy constructor implementation
@@ -101,15 +103,16 @@ GemmaContext::GemmaContext(const LoaderArgs& loader,
                            int max_generated_tokens)
     : inference_args(inference_args),
       threading_args(threading_args),
-      matmul_env(MakeMatMulEnv(threading_args, inference_args)),
+      ctx(UpdateArgs(threading_args, inference_args)),
+      matmul_env(ctx),
       active_conversation_name("default"),
-      model(loader, inference_args, matmul_env.ctx.pools) {
+      model(loader, inference_args, matmul_env.ctx) {
   std::stringstream ss;
   LogDebug("Creating initial ConversationData");
   // Create the initial ConversationData object using make_shared
   active_conversation = std::make_shared<ConversationData>(
-      model.GetModelConfig(), inference_args);
+      model.GetModelConfig(), inference_args, ctx.allocator);
   LogDebug(
       "Storing initial ConversationData in conversation_cache[\"default\"]");
@@ -188,7 +191,7 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
           ? Extents2D(model_config.vit_config.seq_len / (pool_dim * pool_dim),
                       model_config.model_dim)
           : Extents2D(0, 0),
-      MatPadding::kOdd);
+      ctx.allocator, MatPadding::kOdd);
   if (image_data != nullptr) {
     HWY_ASSERT(model_config.wrapping == PromptWrapping::PALIGEMMA ||
                model_config.wrapping == PromptWrapping::GEMMA_VLM);

View File

@@ -41,7 +41,8 @@ namespace gcpp {
 // Struct to hold data for a single conversation thread
 struct ConversationData {
   ConversationData(const ModelConfig& model_config,
-                   const InferenceArgs& inference_args);
+                   const InferenceArgs& inference_args,
+                   const Allocator& allocator);
   ConversationData(const ConversationData& other);
   std::unique_ptr<KVCache> kv_cache;
@@ -178,8 +179,8 @@ class GemmaContext {
       // rewind to initial state.
       active_conversation->abs_pos = 0;
       // Replace the cache within the current ConversationData object
-      active_conversation->kv_cache =
-          std::make_unique<KVCache>(model.GetModelConfig(), inference_args);
+      active_conversation->kv_cache = std::make_unique<KVCache>(
+          model.GetModelConfig(), inference_args, ctx.allocator);
       LogDebug((log_prefix + "Successfully rewound to initial state.").c_str());
     } else {
@@ -197,7 +198,7 @@ class GemmaContext {
     LogDebug("Creating new conversation");
     // Create a new ConversationData object using make_shared
     conversation_cache[name] = std::make_shared<ConversationData>(
-        model.GetModelConfig(), inference_args);
+        model.GetModelConfig(), inference_args, ctx.allocator);
     return true;
   }
@@ -280,6 +281,7 @@ class GemmaContext {
   // Cached args (remain global for the context)
   InferenceArgs inference_args;
   ThreadingArgs threading_args;
+  ThreadingContext ctx;
   MatMulEnv matmul_env;
   std::string active_conversation_name;

View File

@@ -537,7 +537,7 @@ void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
                      const WeightsPtrs& weights, KVCache& kv_cache,
                      MatMulEnv& env, TimingInfo& timing_info) {
   Activations activations(config, runtime_config.prefill_tbatch_size,
-                          kv_cache.SeqLen(), env.row_ptrs);
+                          kv_cache.SeqLen(), env.ctx.allocator, env.row_ptrs);
   AllQueries all_queries(prompt, pos, prefix_end,
                          hwy::Span<KVCache>(&kv_cache, 1));
@@ -555,7 +555,8 @@ void GenerateBatchT(const ModelConfig& config,
   const size_t max_batch_size = HWY_MAX(runtime_config.decode_qbatch_size,
                                         runtime_config.prefill_tbatch_size);
   Activations activations(config, max_batch_size,
-                          all_queries[0].kv_cache.SeqLen(), env.row_ptrs);
+                          all_queries[0].kv_cache.SeqLen(), env.ctx.allocator,
+                          env.row_ptrs);
   for (size_t start = 0; start < all_queries.NumQueries();
        start += runtime_config.decode_qbatch_size) {
@@ -579,7 +580,7 @@ void GenerateImageTokensT(const ModelConfig& config,
     prefill_runtime_config.prefill_tbatch_size =
         num_tokens / (vit_config.pool_dim * vit_config.pool_dim);
     Activations prefill_activations(vit_config, num_tokens, num_tokens,
-                                    env.row_ptrs);
+                                    env.ctx.allocator, env.row_ptrs);
     // Weights are for the full PaliGemma model, not just the ViT part.
     PrefillVit(config, weights, prefill_runtime_config, image, image_tokens,
                prefill_activations, env);
@@ -596,28 +597,14 @@ HWY_EXPORT(GenerateSingleT);
 HWY_EXPORT(GenerateBatchT);
 HWY_EXPORT(GenerateImageTokensT);
-MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args,
-                        const InferenceArgs& inference_args) {
-  if (inference_args.decode_qbatch_size >= 256) {
-    ThreadingArgs copy = threading_args;
-    copy.max_packages = 1;
-    ThreadingContext::SetArgs(copy);
-  } else {
-    ThreadingContext::SetArgs(threading_args);
-  }
-  return MatMulEnv(ThreadingContext::Get());
-}
 Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
-             NestedPools& pools)
+             ThreadingContext& ctx)
     : reader_(loader.weights),
       model_(reader_, loader.tokenizer, loader.wrapping),
      weights_(model_.Config()),
       chat_template_(model_.Tokenizer(), model_.Config().model),
       inference_(inference) {
-  weights_.ReadFromBlobs(model_, reader_, loader, inference, mat_owners_,
-                         pools.Pool());
+  weights_.ReadFromBlobs(model_, reader_, loader, inference, mat_owners_, ctx);
   reader_.CloseFile();
 }

View File

@@ -225,18 +225,15 @@ struct TimingInfo {
   size_t tokens_generated = 0;
 };
-// Returns the `MatMulEnv` after calling `SetArgs`.
-MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args,
-                        const InferenceArgs& inference_args);
 // After construction, all methods are const and thread-compatible if using
-// separate MatMulEnv for each thread.
+// separate ThreadingContext for each thread.
 class Gemma {
  public:
   // Reads weights/config/tokenizer from the `BlobStore` at `loader.weights`.
-  // `pools` are used to parallelize loading.
+  // `ctx` is only used to read tensors, but it is typically also referenced
+  // by the `MatMulEnv` passed to the Generate* methods.
   Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
-        NestedPools& pools);
+        ThreadingContext& ctx);
   ~Gemma();
   // TODO: rename to Config()

View File

@@ -256,6 +256,16 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
   }
 };
+static inline ThreadingArgs UpdateArgs(const ThreadingArgs& threading_args,
+                                       const InferenceArgs& inference_args) {
+  if (inference_args.decode_qbatch_size >= 256) {
+    ThreadingArgs copy = threading_args;
+    copy.max_packages = 1;
+    return copy;
+  }
+  return threading_args;
+}
 }  // namespace gcpp
 #endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
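
Usage sketch for the new `UpdateArgs` helper (the same pattern appears in run.cc below); callers apply it once before constructing the context:

    gcpp::ThreadingArgs threading;   // parsed from the command line
    gcpp::InferenceArgs inference;
    // Caps max_packages to 1 for large decode batches, otherwise returns
    // threading unchanged; the result seeds the caller-owned context.
    gcpp::ThreadingContext ctx(gcpp::UpdateArgs(threading, inference));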

View File

@@ -57,20 +57,24 @@ static size_t CappedSeqLen(const ModelConfig& config,
 }
 KVCache::KVCache(const Extents2D& conv1d_extents,
-                 const Extents2D& rglru_extents, const Extents2D& kv_extents)
-    : conv1d_cache("conv1d_cache", conv1d_extents, MatPadding::kOdd),
-      rglru_cache("rglru_cache", rglru_extents, MatPadding::kOdd),
-      kv_cache("kv", kv_extents, MatPadding::kOdd) {}
+                 const Extents2D& rglru_extents, const Extents2D& kv_extents,
+                 const Allocator& allocator)
+    : conv1d_cache("conv1d_cache", conv1d_extents, allocator, MatPadding::kOdd),
+      rglru_cache("rglru_cache", rglru_extents, allocator, MatPadding::kOdd),
+      kv_cache("kv", kv_extents, allocator, MatPadding::kOdd),
+      allocator_(allocator) {}
-KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args)
-    : KVCache(Extents2D(GriffinLayers(config), GriffinConv1dCols(config)),
-              Extents2D(GriffinLayers(config), config.model_dim),
-              Extents2D(CappedSeqLen(config, inference_args),
-                        config.KVCacheCols())) {}
+KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
+                 const Allocator& allocator)
+    : KVCache(
+          Extents2D(GriffinLayers(config), GriffinConv1dCols(config)),
+          Extents2D(GriffinLayers(config), config.model_dim),
+          Extents2D(CappedSeqLen(config, inference_args), config.KVCacheCols()),
+          allocator) {}
 KVCache KVCache::Copy() {
   KVCache copy(conv1d_cache.Extents(), rglru_cache.Extents(),
-               kv_cache.Extents());
+               kv_cache.Extents(), allocator_);
   if (conv1d_cache.Rows() != 0) {
     CopyMat(conv1d_cache, copy.conv1d_cache);

View File

@@ -28,7 +28,8 @@ namespace gcpp {
 using KV_t = float;
 struct KVCache {
-  KVCache(const ModelConfig& config, const InferenceArgs& inference_args);
+  KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
+          const Allocator& allocator);
   // Returns a deep copy of the KVCache. Use explicit function instead of
   // copy ctor to make the cost explicit.
@@ -47,9 +48,11 @@ struct KVCache {
   MatStorageT<KV_t> kv_cache;  // [seq_len, layers * kv_heads * qkv_dim * 2]
  private:
+  const Allocator& allocator_;
   // For use by other ctor and Copy()
   KVCache(const Extents2D& conv1d_extents, const Extents2D& rglru_extents,
-          const Extents2D& kv_extents);
+          const Extents2D& kv_extents, const Allocator& allocator);
 };
 }  // namespace gcpp
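
Design note: `allocator_` is stored as a reference only so `Copy()` can build the clone with the same allocator; since it is a reference, the `Allocator` (typically owned by a `ThreadingContext`) must outlive every cache created from it. A short sketch, with `config` and `inference` standing for a `ModelConfig` and `InferenceArgs` already in scope:

    // The context (and thus its allocator) must outlive caches built from it.
    gcpp::ThreadingContext ctx{gcpp::ThreadingArgs()};
    gcpp::KVCache cache(config, inference, ctx.allocator);
    gcpp::KVCache clone = cache.Copy();  // reuses cache's stored allocator_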

View File

@@ -110,7 +110,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
       have_image ? Extents2D(config.vit_config.seq_len / (pool_dim * pool_dim),
                              config.model_dim)
                  : Extents2D(0, 0),
-      MatPadding::kOdd);
+      env.ctx.allocator, MatPadding::kOdd);
   image_tokens.AllocateAndAttachRowPtrs(env.row_ptrs);
   if (have_image) {
     HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA ||
@@ -254,10 +254,11 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
          const InferenceArgs& inference) {
   PROFILER_ZONE("Run.misc");
-  MatMulEnv env(MakeMatMulEnv(threading, inference));
+  ThreadingContext ctx(UpdateArgs(threading, inference));
+  MatMulEnv env(ctx);
   if (inference.verbosity >= 2) env.print_best = true;
-  const Gemma gemma(loader, inference, env.ctx.pools);
-  KVCache kv_cache(gemma.GetModelConfig(), inference);
+  const Gemma gemma(loader, inference, ctx);
+  KVCache kv_cache(gemma.GetModelConfig(), inference, ctx.allocator);
   if (inference.verbosity >= 1) {
     std::string instructions =
@@ -284,7 +285,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
   if (inference.IsInteractive()) {
     std::cout << "\033[2J\033[1;1H"  // clear screen
               << kAsciiArtBanner << "\n\n";
-    ShowConfig(loader, threading, inference, gemma.GetModelConfig());
+    ShowConfig(loader, threading, inference, gemma.GetModelConfig(), ctx);
     std::cout << "\n" << instructions << "\n";
   }
 }

View File

@@ -80,11 +80,11 @@ class VitAttention {
     // Shift Q, K, VT to MatStorageT.
     MatStorageT<float> Q("Q2", Extents2D(num_tokens_, qkv_dim),
-                         MatPadding::kPacked);
-    MatStorageT<float> K("K2", Extents2D(seq_len, qkv_dim),
+                         env_.ctx.allocator, MatPadding::kPacked);
+    MatStorageT<float> K("K2", Extents2D(seq_len, qkv_dim), env_.ctx.allocator,
                          MatPadding::kPacked);
     MatStorageT<float> C("C2", Extents2D(num_tokens_, seq_len),
-                         MatPadding::kPacked);
+                         env_.ctx.allocator, MatPadding::kPacked);
     // Initialize att_out to zero prior to head loop.
     ZeroInit(activations_.attention.att_out);
@@ -294,7 +294,7 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image,
   // image_patches is (256, 14 * 14 * 3)
   // Must be padded, see `DoDecompressA`.
   MatStorageT<float> image_patches("patches", Extents2D(num_tokens, patch_size),
-                                   MatPadding::kOdd);
+                                   env.ctx.allocator, MatPadding::kOdd);
   for (size_t i = 0; i < num_tokens; ++i) {
     image.GetPatch(i, image_patches.Row(i));
   }
@@ -329,7 +329,7 @@ void PrefillVit(const ModelConfig& model_config, const WeightsPtrs& weights,
               weights.vit_encoder_norm_bias, activations.x);
   if (model_config.wrapping == PromptWrapping::GEMMA_VLM) {
-    activations.x = AvgPool4x4(activations.x);
+    activations.x = AvgPool4x4(activations.x, env.ctx.allocator);
     // Apply soft embedding norm before input projection.
     CallUpcasted(&weights.mm_embed_norm, [&](const auto* weights_t) {

View File

@@ -44,7 +44,8 @@
 namespace gcpp {
 // Copies att_weights from `attn_vec_einsum_w`.
-void LayerWeightsPtrs::InitAttWeights(std::vector<MatOwner>& mat_owners) {
+void LayerWeightsPtrs::InitAttWeights(std::vector<MatOwner>& mat_owners,
+                                      const Allocator& allocator) {
   // We only use this tensor for Gemma layers.
   if (layer_config.type != LayerAttentionType::kGemma) return;
@@ -71,7 +72,7 @@ void LayerWeightsPtrs::InitAttWeights(std::vector<MatOwner>& mat_owners) {
     static std::mutex m;
     std::lock_guard<std::mutex> lock(m);
     mat_owners.push_back(MatOwner());
-    mat_owners.back().AllocateFor(att_weights, MatPadding::kOdd);
+    mat_owners.back().AllocateFor(att_weights, allocator, MatPadding::kOdd);
   }
   const size_t T_bytes = att_weights.ElementBytes();
@@ -149,9 +150,10 @@ void LayerWeightsPtrs::SplitAttW1() {
 // Must be called after reading weights via `ForEachTensor`.
 // TODO: exporters should bake this into the weights already.
 // WARNING: called from multiple threads; `mat_owners` requires a lock.
-void LayerWeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners) {
+void LayerWeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
+                             const Allocator& allocator) {
   // TODO(janwas): handle NUQ
-  InitAttWeights(mat_owners);
+  InitAttWeights(mat_owners, allocator);
   SplitW1();
   SplitAttW1();
 }
@@ -223,13 +225,15 @@ void WeightsPtrs::CopyFrom(const WeightsPtrs& other) {
 // For reshaping file tensors to the shape expected by the code. This would
 // ideally already happen in the importer. Called by WeightsOwner::Fixup.
 void WeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
-                        hwy::ThreadPool& pool) {
+                        ThreadingContext& ctx) {
+  // TODO: use 1D parallel-for helper function
+  hwy::ThreadPool& pool = ctx.pools.Pool();
   pool.Run(0, c_layers.size(), [&](uint64_t layer, size_t /*thread*/) {
-    GetLayer(layer)->Fixup(mat_owners);
+    GetLayer(layer)->Fixup(mat_owners, ctx.allocator);
   });
   pool.Run(0, vit_layers.size(), [&](uint64_t layer, size_t /*thread*/) {
-    VitLayer(layer)->Fixup(mat_owners);
+    VitLayer(layer)->Fixup(mat_owners, ctx.allocator);
   });
 }
@@ -260,12 +264,12 @@ enum class Mode {
 // Decides whether to read or map based on heuristics and user override.
 static Mode ChooseMode(uint64_t file_bytes, const LoaderArgs& loader,
-                       const InferenceArgs& inference) {
+                       const InferenceArgs& inference,
+                       const Allocator& allocator) {
   Tristate to_bf16 = loader.to_bf16;
   Tristate map = loader.map;
   // Disable mapping if not padded to the base page size.
-  const Allocator& allocator = ThreadingContext::Get().allocator;
   if (file_bytes % allocator.BasePageBytes() != 0) {
     if (map == Tristate::kTrue) {  // Only complain if explicitly requested.
       HWY_WARN("Unable to map non-padded file (%zu, %zu), reading instead.",
@@ -321,14 +325,15 @@ struct TensorToRead {
 // Allocates multiple in parallel and binds to NUMA nodes.
 static void AllocateAndBindAll(std::vector<TensorToRead>& tensors,
                                const Mode mode, std::vector<MatOwner>& owners,
-                               hwy::ThreadPool& pool) {
+                               ThreadingContext& ctx) {
   const size_t start = owners.size();
   owners.resize(start + tensors.size());
-  MMParallel parallel(ThreadingContext::Get());
+  MMParallel parallel(ctx);
   // Allocate in parallel because faulting in large tensors is slow.
-  pool.Run(0, tensors.size(), [&](uint64_t task, size_t /*thread*/) {
-    TensorToRead& tensor = tensors[task];
-    MatPtr& mat = *tensor.mat;
+  ctx.pools.Pool().Run(
+      0, tensors.size(), [&](uint64_t task, size_t /*thread*/) {
+        TensorToRead& tensor = tensors[task];
+        MatPtr& mat = *tensor.mat;
@@ -341,9 +346,9 @@ static void AllocateAndBindAll(std::vector<TensorToRead>& tensors,
          mat.SetType(Type::kBF16);
        }
-        owners[start + task].AllocateFor(*tensor.mat, tensor.padding);
-        // TODO(janwas): MatMul outputs will later also be BF16.
-        BindB(*tensor.mat, sizeof(float), parallel);
+        owners[start + task].AllocateFor(*tensor.mat, ctx.allocator,
+                                         tensor.padding);
+        BindB(*tensor.mat, tensor.mat->ElementBytes(), parallel);
       });
 }
@@ -482,7 +487,7 @@ static void ReadBatches(const BlobReader& reader,
 // Aborts on error.
 static void MapOrReadAll(std::vector<TensorToRead>& tensors, BlobReader& reader,
                          Mode mode, std::vector<MatOwner>& mat_owners,
-                         hwy::ThreadPool& pool) {
+                         ThreadingContext& ctx) {
   if (mode == Mode::kMap) {
     MapPtr mapped = reader.file().Map();
     if (mapped) return MapAll(tensors, mapped);
@@ -496,9 +501,11 @@ static void MapOrReadAll(std::vector<TensorToRead>& tensors, BlobReader& reader,
   {
     PROFILER_ZONE("Startup.Weights.Allocate");
     // NOTE: this changes the stride of `mats`!
-    AllocateAndBindAll(tensors, mode, mat_owners, pool);
+    AllocateAndBindAll(tensors, mode, mat_owners, ctx);
   }
+  hwy::ThreadPool& pool = ctx.pools.Pool();
   if (mode == Mode::kReadBF16) return ReadAllToBF16(tensors, reader, pool);
   const std::vector<IOBatch> batches =
@@ -510,7 +517,7 @@ void WeightsPtrs::ReadFromBlobs(const ModelStore& model, BlobReader& reader,
                                 const LoaderArgs& loader,
                                 const InferenceArgs& inference,
                                 std::vector<MatOwner>& mat_owners,
-                                hwy::ThreadPool& pool) {
+                                ThreadingContext& ctx) {
   // List of tensors to read/map, and where from.
   std::vector<TensorToRead> tensors;
@@ -529,13 +536,14 @@ void WeightsPtrs::ReadFromBlobs(const ModelStore& model, BlobReader& reader,
     HWY_ABORT("Tensor %s is required but not found in file.", t.mat.Name());
   });
-  const Mode mode = ChooseMode(reader.file_bytes(), loader, inference);
-  MapOrReadAll(tensors, reader, mode, mat_owners, pool);
+  const Mode mode =
+      ChooseMode(reader.file_bytes(), loader, inference, ctx.allocator);
+  MapOrReadAll(tensors, reader, mode, mat_owners, ctx);
   {
     PROFILER_ZONE("Startup.Fixup");
-    Fixup(mat_owners, pool);
+    Fixup(mat_owners, ctx);
   }
 }
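
The weights.cc BindB fix called out in the commit message, isolated for clarity (context as in AllocateAndBindAll above):

    // Before: rows were bound assuming float-sized elements, which is wrong
    // once tensors are converted to BF16 while reading:
    //   BindB(*tensor.mat, sizeof(float), parallel);
    // After: use the tensor's actual element size (2 for BF16, 4 for float).
    BindB(*tensor.mat, tensor.mat->ElementBytes(), parallel);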

View File

@@ -29,7 +29,7 @@
 #include "gemma/tensor_info.h"  // TensorInfoRegistry
 #include "io/blob_store.h"      // BlobWriter
 #include "util/mat.h"           // MatPtr
-#include "hwy/contrib/thread_pool/thread_pool.h"
+#include "util/threading_context.h"
 namespace gcpp {
@@ -299,11 +299,12 @@ struct LayerWeightsPtrs {
   // Must be called after reading weights via `ForEachTensor`.
   // TODO: exporters should bake this into the weights already.
   // WARNING: called from multiple threads; `mat_owners` requires a lock.
-  void Fixup(std::vector<MatOwner>& mat_owners);
+  void Fixup(std::vector<MatOwner>& mat_owners, const Allocator& allocator);
  private:
   // Copies att_weights from `attn_vec_einsum_w`.
-  void InitAttWeights(std::vector<MatOwner>& mat_owners);
+  void InitAttWeights(std::vector<MatOwner>& mat_owners,
+                      const Allocator& allocator);
   // For FFN. Fast, only updates pointers.
   void SplitW1();
@@ -426,7 +427,7 @@ struct WeightsPtrs {
   // override for whether to map blobs or read them.
   void ReadFromBlobs(const ModelStore& model, BlobReader& reader,
                      const LoaderArgs& loader, const InferenceArgs& inference,
-                     std::vector<MatOwner>& mat_owners, hwy::ThreadPool& pool);
+                     std::vector<MatOwner>& mat_owners, ThreadingContext& ctx);
   // Adds one blob for each tensor's data and returns all serialized MatPtr.
   std::vector<uint32_t> AddTensorDataToWriter(BlobWriter& writer) const;
@@ -434,7 +435,7 @@ struct WeightsPtrs {
  private:
   // For reshaping file tensors to the shape expected by the code. This would
   // ideally already happen in the importer. Called by ReadFromBlobs.
-  void Fixup(std::vector<MatOwner>& mat_owners, hwy::ThreadPool& pool);
+  void Fixup(std::vector<MatOwner>& mat_owners, ThreadingContext& ctx);
 };  // `WeightsPtrs`
 #undef TENSOR_ARGS

View File

@@ -227,11 +227,12 @@ void ReadAndCompareBlobs(const Path& path1, const Path& path2) {
   BlobVec blobs1 = ReserveMemory(ranges1, all_blobs, pos);
   BlobVec blobs2 = ReserveMemory(ranges2, all_blobs, pos);
-  NestedPools& pools = ThreadingContext::Get().pools;
+  ThreadingArgs args;
+  ThreadingContext ctx(args);
   ReadBothBlobs(reader1, reader2, ranges1, ranges2, total_bytes, blobs1, blobs2,
-                pools);
-  CompareBlobs(reader1.Keys(), blobs1, blobs2, total_bytes, pools);
+                ctx.pools);
+  CompareBlobs(reader1.Keys(), blobs1, blobs2, total_bytes, ctx.pools);
 }
 }  // namespace gcpp

View File

@@ -36,7 +36,9 @@ class BlobStoreTest : public testing::Test {};
 #endif
 TEST(BlobStoreTest, TestReadWrite) {
-  hwy::ThreadPool& pool = ThreadingContext::Get().pools.Pool();
+  ThreadingArgs threading_args;
+  ThreadingContext ctx(threading_args);
+  hwy::ThreadPool& pool = ctx.pools.Pool();
   static const std::array<float, 4> kOriginalData = {-1, 0, 3.14159, 2.71828};
@@ -92,7 +94,9 @@ TEST(BlobStoreTest, TestReadWrite) {
 // Ensures padding works for any number of random-sized blobs.
 TEST(BlobStoreTest, TestNumBlobs) {
-  hwy::ThreadPool& pool = ThreadingContext::Get().pools.Pool();
+  ThreadingArgs threading_args;
+  ThreadingContext ctx(threading_args);
+  hwy::ThreadPool& pool = ctx.pools.Pool();
   hwy::RandomState rng;
   for (size_t num_blobs = 1; num_blobs <= 512; ++num_blobs) {

View File

@@ -84,19 +84,22 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
   const Extents2D B_extents(N, K);  // already transposed
   const Extents2D C_extents(M, N);
-  MatStorageT<TC> C_slow("c_slow_batch", C_extents, MatPadding::kOdd);
-  MatStorageT<TC> C("c_batch", C_extents, MatPadding::kOdd);
-  MatStorageT<float> add_storage("add", Extents2D(), MatPadding::kPacked);
+  MatStorageT<TC> C_slow("c_slow_batch", C_extents, env.ctx.allocator,
+                         MatPadding::kOdd);
+  MatStorageT<TC> C("c_batch", C_extents, env.ctx.allocator, MatPadding::kOdd);
+  MatStorageT<float> add_storage("add", Extents2D(), env.ctx.allocator,
+                                 MatPadding::kPacked);
   if (add) {
-    add_storage =
-        GenerateMat<float>(Extents2D(1, N), MatPadding::kPacked, pool);
+    add_storage = GenerateMat<float>(Extents2D(1, N), env.ctx.allocator,
+                                     MatPadding::kPacked, pool);
     add_storage.SetScale(1.0f);
   }
-  MatStorageT<TA> a = GenerateMat<TA>(A_extents, MatPadding::kOdd, pool);
-  MatStorageT<TB> b_trans =
-      GenerateTransposedMat<TB>(B_extents, MatPadding::kOdd, pool);
+  MatStorageT<TA> a =
+      GenerateMat<TA>(A_extents, env.ctx.allocator, MatPadding::kOdd, pool);
+  MatStorageT<TB> b_trans = GenerateTransposedMat<TB>(
+      B_extents, env.ctx.allocator, MatPadding::kOdd, pool);
   const float* add_row = add ? add_storage.PackedScale1() : nullptr;
@@ -151,10 +154,10 @@ void BenchAllMatMul() {
     return;
   }
-  ThreadingContext& ctx = ThreadingContext::Get();
+  ThreadingArgs threading_args;
+  ThreadingContext ctx(threading_args);
   fprintf(stderr, "BenchAllMatMul %s %s\n", ctx.topology.TopologyString(),
           ctx.pools.PinString());
   MatMulEnv env(ctx);
   for (size_t batch_size : {1, 4, 128, 512}) {

View File

@@ -999,6 +999,8 @@ struct TestShortDotsT {
     const size_t N = hn::Lanes(d);
     const hn::ScalableTag<float> df;  // for CallDot
+    ThreadingArgs threading_args;
+    ThreadingContext ctx(threading_args);
     CompressWorkingSet work;
     std::mt19937 rng;
     rng.seed(12345);
@@ -1009,14 +1011,14 @@ struct TestShortDotsT {
     // GenerateWellConditionedInputs calls DecompressAndZeroPad to `raw*`,
     // hence they require padding to one vector.
     const size_t padded_num = hwy::RoundUpTo(num, N);
-    MatStorageT<float> raw_w("raw_w", padded_num);
-    MatStorageT<float> raw_v("raw_v", padded_num);
-    MatStorageT<Packed> weights("weights", padded_num);
+    MatStorageT<float> raw_w("raw_w", padded_num, ctx.allocator);
+    MatStorageT<float> raw_v("raw_v", padded_num, ctx.allocator);
+    MatStorageT<Packed> weights("weights", padded_num, ctx.allocator);
     const PackedSpan<Packed> w = weights.Span();
-    MatStorageT<T> vectors("vectors", padded_num);
+    MatStorageT<T> vectors("vectors", padded_num, ctx.allocator);
     const PackedSpan<T> v = vectors.Span();
-    MatStorageT<double> bufs("bufs", num);
+    MatStorageT<double> bufs("bufs", padded_num, ctx.allocator);
     double* HWY_RESTRICT buf = bufs.Row(0);
     for (size_t rep = 0; rep < hn::AdjustedReps(20); ++rep) {
@@ -1097,14 +1099,12 @@ void TestAllDot() {
   constexpr size_t kMaxWorkers = 15;
-  // Reset with cap on workers because we only support `kMaxWorkers`.
-  ThreadingContext::ThreadHostileInvalidate();
+  // Limit workers because we only support `kMaxWorkers`.
   ThreadingArgs threading_args;
   threading_args.max_packages = 1;
   threading_args.max_clusters = 1;
   threading_args.max_lps = kMaxWorkers - 1;
-  ThreadingContext::SetArgs(threading_args);
-  ThreadingContext& ctx = ThreadingContext::Get();
+  ThreadingContext ctx(threading_args);
   {  // ensure no profiler zones are active
     const hn::ScalableTag<float> df;
@@ -1116,9 +1116,11 @@ void TestAllDot() {
   constexpr size_t kReps = hn::AdjustedReps(40);
   const size_t num = 24 * 1024;
-  MatStorageT<float> a("a", Extents2D(kMaxWorkers, num), MatPadding::kOdd);
-  MatStorageT<float> b("b", Extents2D(kMaxWorkers, num), MatPadding::kOdd);
-  MatStorageT<double> bufs("bufs", Extents2D(kMaxWorkers, num),
+  MatStorageT<float> a("a", Extents2D(kMaxWorkers, num), ctx.allocator,
+                       MatPadding::kOdd);
+  MatStorageT<float> b("b", Extents2D(kMaxWorkers, num), ctx.allocator,
+                       MatPadding::kOdd);
+  MatStorageT<double> bufs("bufs", Extents2D(kMaxWorkers, num), ctx.allocator,
                            MatPadding::kOdd);
   std::array<DotStats, kMaxWorkers> all_stats;

View File

@@ -26,6 +26,7 @@
 #include <memory>
 #include "util/mat.h"
+#include "util/threading_context.h"
 #include "hwy/aligned_allocator.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
@@ -68,10 +69,11 @@ FloatPtr SimpleMatVecAdd(const MatStorageT<float>& mat, const FloatPtr& vec,
 template <typename MatT, size_t kOuter, size_t kInner>
 std::unique_ptr<MatStorageT<float>> GenerateMat(size_t offset,
+                                                const Allocator& allocator,
                                                 hwy::ThreadPool& pool) {
   gcpp::CompressWorkingSet ws;
   const Extents2D extents(kOuter, kInner);
-  auto mat = std::make_unique<MatStorageT<float>>("TestMat", extents,
+  auto mat = std::make_unique<MatStorageT<float>>("TestMat", extents, allocator,
                                                   MatPadding::kPacked);
   FloatPtr raw_mat = hwy::AllocateAligned<float>(extents.Area());
   HWY_ASSERT(raw_mat);
@@ -109,10 +111,12 @@ void AssertClose(const FloatPtr& a, const FloatPtr& b) {
 }
 void TestMatVecAdd() {
-  hwy::ThreadPool pool(hwy::ThreadPool::MaxThreads());
+  ThreadingArgs threading_args;
+  ThreadingContext ctx(threading_args);
+  hwy::ThreadPool& pool = ctx.pools.Pool();
   constexpr size_t kOuter = 128 * 3;
   constexpr size_t kInner = 128 * 5;
-  auto mat = GenerateMat<float, kOuter, kInner>(0, pool);
+  auto mat = GenerateMat<float, kOuter, kInner>(0, ctx.allocator, pool);
   FloatPtr vec = GenerateVec<kInner>(0);
   FloatPtr add = GenerateVec<kOuter>(0);
   FloatPtr expected_out = SimpleMatVecAdd(*mat, vec, add);
@@ -124,11 +128,13 @@ void TestMatVecAdd() {
 }
 void TestTwoMatVecAdd() {
-  hwy::ThreadPool pool(hwy::ThreadPool::MaxThreads());
+  ThreadingArgs threading_args;
+  ThreadingContext ctx(threading_args);
+  hwy::ThreadPool& pool = ctx.pools.Pool();
   constexpr size_t kOuter = 128 * 3;
   constexpr size_t kInner = 128 * 5;
-  auto mat0 = GenerateMat<float, kOuter, kInner>(0, pool);
-  auto mat1 = GenerateMat<float, kOuter, kInner>(1, pool);
+  auto mat0 = GenerateMat<float, kOuter, kInner>(0, ctx.allocator, pool);
+  auto mat1 = GenerateMat<float, kOuter, kInner>(1, ctx.allocator, pool);
   FloatPtr vec = GenerateVec<kInner>(0);
   FloatPtr add0 = GenerateVec<kOuter>(0);
   FloatPtr add1 = GenerateVec<kOuter>(1);
@@ -145,10 +151,13 @@ void TestTwoMatVecAdd() {
 }
 void TestTwoOfsMatVecAddLoop() {
-  hwy::ThreadPool pool(hwy::ThreadPool::MaxThreads());
+  ThreadingArgs threading_args;
+  ThreadingContext ctx(threading_args);
+  hwy::ThreadPool& pool = ctx.pools.Pool();
   constexpr size_t kOuter = 128 * 3;
   constexpr size_t kInner = 128 * 5;
-  auto mat = GenerateMat<float, kOuter, kInner>(0, pool);
+  auto mat = GenerateMat<float, kOuter, kInner>(0, ctx.allocator, pool);
   FloatPtr vec = GenerateVec<kInner>(0);
   FloatPtr add0 = GenerateVec<kOuter>(0);
   FloatPtr add1 = GenerateVec<kOuter>(1);

View File

@ -53,8 +53,8 @@ constexpr size_t kNR = 4;
// or less on ISAs with fewer registers, or for the last few rows of A. // or less on ISAs with fewer registers, or for the last few rows of A.
static constexpr size_t kMaxMR = 4; static constexpr size_t kMaxMR = 4;
// Mostly stateless, can be constructed on the fly by weights.cc, but captures // Mostly stateless, can be constructed on the fly by weights.cc. Captures
// the singleton ThreadingContext to reduce MatMul call overhead. // the ThreadingContext to shorten call sites.
class MMParallel { class MMParallel {
public: public:
static constexpr size_t kMaxPackages = 4; static constexpr size_t kMaxPackages = 4;
@ -251,7 +251,7 @@ class MMStorage {
: // Per-worker copies of `partial` would be wasteful. We instead : // Per-worker copies of `partial` would be wasteful. We instead
// allocate one instance of the maximum matrix extents because threads // allocate one instance of the maximum matrix extents because threads
// write at false-sharing-free granularity. // write at false-sharing-free granularity.
partial_storage_("partial_storage", Extents2D(kMaxM, kMaxN), partial_storage_("partial_storage", Extents2D(kMaxM, kMaxN), allocator,
MatPadding::kOdd), MatPadding::kOdd),
// Same stride independent of the actual C.Cols() so we can pre-bind. // Same stride independent of the actual C.Cols() so we can pre-bind.
partial_(partial_storage_.Row(0), kMaxN, partial_storage_.Stride()) { partial_(partial_storage_.Row(0), kMaxN, partial_storage_.Stride()) {
@ -259,7 +259,7 @@ class MMStorage {
// Must be padded, see `DoDecompressA`. // Must be padded, see `DoDecompressA`.
parallel.ForPkg(MMParallel::kMaxPackages, [&](size_t pkg_idx) { parallel.ForPkg(MMParallel::kMaxPackages, [&](size_t pkg_idx) {
pkg_A_[pkg_idx].reset(new MatStorageT<BF16>( pkg_A_[pkg_idx].reset(new MatStorageT<BF16>(
"pkg_A", Extents2D(kMaxM, kMaxK), MatPadding::kOdd)); "pkg_A", Extents2D(kMaxM, kMaxK), allocator, MatPadding::kOdd));
if (allocator.ShouldBind()) { if (allocator.ShouldBind()) {
const size_t node = parallel.Node(pkg_idx); const size_t node = parallel.Node(pkg_idx);

View File

@ -91,14 +91,15 @@ void AssertClose(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
const size_t cols = A.Cols(); const size_t cols = A.Cols();
const size_t B_rows = B.Rows(); const size_t B_rows = B.Rows();
// Round up for DecompressAndZeroPad. // Round up for DecompressAndZeroPad.
MatStorageT<float> a_batch("a_batch", A.Extents(), MatPadding::kOdd); MatStorageT<float> a_batch("a_batch", A.Extents(), env.ctx.allocator,
MatPadding::kOdd);
MatStorageT<float> b_trans_batch("b_trans_batch", B.Extents(), MatStorageT<float> b_trans_batch("b_trans_batch", B.Extents(),
MatPadding::kOdd); env.ctx.allocator, MatPadding::kOdd);
MatStorageT<float> c_batch("c_batch", Extents2D(A.Rows(), B_rows), MatStorageT<float> c_batch("c_batch", Extents2D(A.Rows(), B_rows),
MatPadding::kOdd); env.ctx.allocator, MatPadding::kOdd);
c_batch.AllocateAndAttachRowPtrs(env.row_ptrs); c_batch.AllocateAndAttachRowPtrs(env.row_ptrs);
MatStorageT<float> c_slow_batch("c_slow_batch", Extents2D(A.Rows(), B_rows), MatStorageT<float> c_slow_batch("c_slow_batch", Extents2D(A.Rows(), B_rows),
MatPadding::kOdd); env.ctx.allocator, MatPadding::kOdd);
for (size_t m = 0; m < A.Rows(); ++m) { for (size_t m = 0; m < A.Rows(); ++m) {
DecompressAndZeroPad(df, MakeSpan(A.Row(m), cols), 0, a_batch.Row(m), cols); DecompressAndZeroPad(df, MakeSpan(A.Row(m), cols), 0, a_batch.Row(m), cols);
DecompressAndZeroPad(df, MakeSpan(C.Row(m), B_rows), 0, c_batch.Row(m), DecompressAndZeroPad(df, MakeSpan(C.Row(m), B_rows), 0, c_batch.Row(m),
@ -219,17 +220,21 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
const Extents2D B_extents(cols_bc, cols_a_rows_b); // already transposed const Extents2D B_extents(cols_bc, cols_a_rows_b); // already transposed
const Extents2D C_extents(rows_ac, cols_bc); const Extents2D C_extents(rows_ac, cols_bc);
MatStorageT<TA> A(GenerateMat<TA>(A_extents, MatPadding::kOdd, pool)); MatStorageT<TA> A(
GenerateMat<TA>(A_extents, env.ctx.allocator, MatPadding::kOdd, pool));
// Must be packed because we call Span() on it. // Must be packed because we call Span() on it.
MatStorageT<TB> BT( MatStorageT<TB> BT(GenerateTransposedMat<TB>(B_extents, env.ctx.allocator,
GenerateTransposedMat<TB>(B_extents, MatPadding::kPacked, pool)); MatPadding::kPacked, pool));
MatStorageT<TC> C_slow("C_slow", C_extents, MatPadding::kOdd); MatStorageT<TC> C_slow("C_slow", C_extents, env.ctx.allocator,
MatStorageT<TC> C("C", C_extents, MatPadding::kOdd); MatPadding::kOdd);
MatStorageT<TC> C("C", C_extents, env.ctx.allocator, MatPadding::kOdd);
C.AllocateAndAttachRowPtrs(env.row_ptrs); C.AllocateAndAttachRowPtrs(env.row_ptrs);
MatStorageT<float> add_storage = MatStorageT<float> add_storage =
add ? GenerateMat<float>(Extents2D(1, cols_bc), MatPadding::kPacked, pool) add ? GenerateMat<float>(Extents2D(1, cols_bc), env.ctx.allocator,
: MatStorageT<float>("add", Extents2D(), MatPadding::kPacked); MatPadding::kPacked, pool)
: MatStorageT<float>("add", Extents2D(), env.ctx.allocator,
MatPadding::kPacked);
add_storage.SetScale(1.0f); add_storage.SetScale(1.0f);
const float* add_row = add ? add_storage.PackedScale1() : nullptr; const float* add_row = add ? add_storage.PackedScale1() : nullptr;
@ -252,12 +257,11 @@ void TestTiny() {
if (HWY_TARGET != first_target) return; if (HWY_TARGET != first_target) return;
for (size_t max_packages : {1, 2}) { for (size_t max_packages : {1, 2}) {
ThreadingContext::ThreadHostileInvalidate();
ThreadingArgs threading_args; ThreadingArgs threading_args;
threading_args.bind = Tristate::kTrue; threading_args.bind = Tristate::kTrue;
threading_args.max_packages = max_packages; threading_args.max_packages = max_packages;
ThreadingContext::SetArgs(threading_args); ThreadingContext ctx(threading_args);
MatMulEnv env(ThreadingContext::Get()); MatMulEnv env(ctx);
NestedPools& pools = env.ctx.pools; NestedPools& pools = env.ctx.pools;
if constexpr (GEMMA_DISABLE_TOPOLOGY) { if constexpr (GEMMA_DISABLE_TOPOLOGY) {
@ -291,11 +295,10 @@ void TestAllMatMul() {
return; return;
} }
ThreadingContext::ThreadHostileInvalidate();
ThreadingArgs threading_args; ThreadingArgs threading_args;
threading_args.bind = Tristate::kTrue; threading_args.bind = Tristate::kTrue;
ThreadingContext::SetArgs(threading_args); ThreadingContext ctx(threading_args);
MatMulEnv env(ThreadingContext::Get()); MatMulEnv env(ctx);
NestedPools& pools = env.ctx.pools; NestedPools& pools = env.ctx.pools;
pools.MaybeStartSpinning(threading_args.spin); pools.MaybeStartSpinning(threading_args.spin);

View File

@ -1018,13 +1018,13 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
// Input has 4096 (64*64) rows, output has 256 (16*16) rows // Input has 4096 (64*64) rows, output has 256 (16*16) rows
// Each output row is the average of a 4x4 block of input rows // Each output row is the average of a 4x4 block of input rows
template <typename T> template <typename T>
MatStorageT<T> AvgPool4x4(MatStorageT<T>& input) { MatStorageT<T> AvgPool4x4(MatStorageT<T>& input, const Allocator& allocator) {
const Extents2D extents = input.Extents(); const Extents2D extents = input.Extents();
// Input validation // Input validation
HWY_DASSERT(extents.rows == 4096); // 64 * 64 = 4096 input rows HWY_DASSERT(extents.rows == 4096); // 64 * 64 = 4096 input rows
// Create output with 256 rows and same number of columns // Create output with 256 rows and same number of columns
const size_t out_rows = 256; // 16 * 16 = 256 output rows const size_t out_rows = 256; // 16 * 16 = 256 output rows
MatStorageT<T> result("pool4x4", Extents2D(out_rows, extents.cols), MatStorageT<T> result("pool4x4", Extents2D(out_rows, extents.cols), allocator,
MatPadding::kOdd); MatPadding::kOdd);
const size_t input_dim = 64; // Input is 64×64 const size_t input_dim = 64; // Input is 64×64
const size_t output_dim = 16; // Output is 16×16 const size_t output_dim = 16; // Output is 16×16
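The comment above says each of the 256 output rows averages a 4x4 block of the 4096 input rows. A minimal sketch of the index mapping this implies, assuming the input rows form a row-major 64x64 grid; the helper below is illustrative only and not part of this commit:

  // Maps output row (ro, co) in the 16x16 grid and block offset (i, j) in
  // [0, 4) to the corresponding input row, assuming row-major 64x64 layout.
  inline size_t AvgPool4x4InputRow(size_t ro, size_t co, size_t i, size_t j) {
    constexpr size_t kInputDim = 64;
    return (4 * ro + i) * kInputDim + (4 * co + j);
  }
  // Each output row is then the elementwise mean over the 16 input rows
  // addressed by i, j in [0, 4).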

View File

@ -26,9 +26,10 @@
namespace gcpp { namespace gcpp {
static inline HWY_MAYBE_UNUSED MatStorageT<float> CreateInvTimescale( static inline HWY_MAYBE_UNUSED MatStorageT<float> CreateInvTimescale(
size_t qkv_dim, bool half_rope, double base_frequency = 10000.0) { const Allocator& allocator, size_t qkv_dim, bool half_rope,
double base_frequency = 10000.0) {
const size_t rope_dim = half_rope ? qkv_dim / 2 : qkv_dim; const size_t rope_dim = half_rope ? qkv_dim / 2 : qkv_dim;
MatStorageT<float> inv_timescale("inv_timescale", rope_dim / 2); MatStorageT<float> inv_timescale("inv_timescale", rope_dim / 2, allocator);
for (size_t dim = 0; dim < rope_dim / 2; ++dim) { for (size_t dim = 0; dim < rope_dim / 2; ++dim) {
const double freq_exponents = const double freq_exponents =
static_cast<double>(2 * dim) / static_cast<double>(rope_dim); static_cast<double>(2 * dim) / static_cast<double>(rope_dim);
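A call-site sketch for the new signature, mirroring the ops_test change further below: the allocator now comes first, and the dimensions here are arbitrary example values.

  ThreadingArgs threading_args;
  ThreadingContext ctx(threading_args);
  MatStorageT<float> inv_timescale =
      CreateInvTimescale(ctx.allocator, /*qkv_dim=*/256, /*half_rope=*/false);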

View File

@ -347,10 +347,12 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
} }
void TestRopeAndMulBy() { void TestRopeAndMulBy() {
ThreadingArgs threading_args;
ThreadingContext ctx(threading_args);
const ModelConfig config(Model::GEMMA2_9B, Type::kSFP, const ModelConfig config(Model::GEMMA2_9B, Type::kSFP,
ChooseWrapping(Model::GEMMA2_9B)); ChooseWrapping(Model::GEMMA2_9B));
const size_t dim_qkv = config.layer_configs[0].qkv_dim; const size_t dim_qkv = config.layer_configs[0].qkv_dim;
MatStorageT<float> x("x", dim_qkv); MatStorageT<float> x("x", dim_qkv, ctx.allocator);
std::mt19937 gen; std::mt19937 gen;
gen.seed(0x12345678); gen.seed(0x12345678);
@ -364,13 +366,13 @@ void TestRopeAndMulBy() {
const float qmul = AttentionActivations::ChooseQueryScale(config); const float qmul = AttentionActivations::ChooseQueryScale(config);
constexpr float kmul = 1.0f; constexpr float kmul = 1.0f;
MatStorageT<float> qexpected("qexpected", dim_qkv); MatStorageT<float> qexpected("qexpected", dim_qkv, ctx.allocator);
MatStorageT<float> qactual("qactual", dim_qkv); MatStorageT<float> qactual("qactual", dim_qkv, ctx.allocator);
MatStorageT<float> kexpected("kexpected", dim_qkv); MatStorageT<float> kexpected("kexpected", dim_qkv, ctx.allocator);
MatStorageT<float> kactual("kactual", dim_qkv); MatStorageT<float> kactual("kactual", dim_qkv, ctx.allocator);
MatStorageT<float> kactual2("kactual2", dim_qkv); MatStorageT<float> kactual2("kactual2", dim_qkv, ctx.allocator);
MatStorageT<float> inv_timescale = CreateInvTimescale( MatStorageT<float> inv_timescale = CreateInvTimescale(
config.layer_configs[0].qkv_dim, ctx.allocator, config.layer_configs[0].qkv_dim,
config.layer_configs[0].post_qk == PostQKType::HalfRope); config.layer_configs[0].post_qk == PostQKType::HalfRope);
// Assert VectorizedRope computation is same as regular rope at different pos. // Assert VectorizedRope computation is same as regular rope at different pos.
for (size_t pos = 1; pos < 500; pos++) { for (size_t pos = 1; pos < 500; pos++) {

View File

@ -21,7 +21,7 @@ void PaliGemmaHelper::InitVit(const std::string& path) {
image_tokens_ = std::make_unique<ImageTokens>( image_tokens_ = std::make_unique<ImageTokens>(
"image", Extents2D(config.vit_config.seq_len, config.model_dim), "image", Extents2D(config.vit_config.seq_len, config.model_dim),
MatPadding::kPacked); env_->Env().ctx.allocator, MatPadding::kPacked);
image_tokens_->AllocateAndAttachRowPtrs(env_->Env().row_ptrs); image_tokens_->AllocateAndAttachRowPtrs(env_->Env().row_ptrs);
Image image; Image image;
HWY_ASSERT(image.ReadPPM(path)); HWY_ASSERT(image.ReadPPM(path));

View File

@ -186,7 +186,7 @@ class GemmaModel {
image_tokens_.reset(new gcpp::ImageTokens( image_tokens_.reset(new gcpp::ImageTokens(
"image_tokens", "image_tokens",
gcpp::Extents2D(config.vit_config.seq_len, config.model_dim), gcpp::Extents2D(config.vit_config.seq_len, config.model_dim),
gcpp::MatPadding::kOdd)); env_.MutableEnv().ctx.allocator, gcpp::MatPadding::kOdd));
gcpp::RuntimeConfig runtime_config = {.gen = &env_.MutableGen(), gcpp::RuntimeConfig runtime_config = {.gen = &env_.MutableGen(),
.verbosity = 0}; .verbosity = 0};
gemma.GenerateImageTokens(runtime_config, env_.MutableKVCache().SeqLen(), gemma.GenerateImageTokens(runtime_config, env_.MutableKVCache().SeqLen(),

View File

@ -78,10 +78,10 @@ size_t Stride(MatPadding padding, size_t cols, size_t element_bytes,
} }
} }
void MatOwner::AllocateFor(MatPtr& mat, MatPadding padding) { void MatOwner::AllocateFor(MatPtr& mat, const Allocator& allocator,
MatPadding padding) {
const bool is_nuq = mat.GetType() == Type::kNUQ; const bool is_nuq = mat.GetType() == Type::kNUQ;
if (is_nuq) padding = MatPadding::kPacked; if (is_nuq) padding = MatPadding::kPacked;
const Allocator& allocator = ThreadingContext::Get().allocator;
const size_t stride = const size_t stride =
Stride(padding, mat.Cols(), mat.ElementBytes(), allocator.LineBytes()); Stride(padding, mat.Cols(), mat.ElementBytes(), allocator.LineBytes());
const size_t num = is_nuq ? mat.PackedBytes() : mat.Rows() * stride; const size_t num = is_nuq ? mat.PackedBytes() : mat.Rows() * stride;

View File

@ -443,7 +443,7 @@ class MatOwner {
// Allocates the type/extents indicated by `mat` and sets its pointer. // Allocates the type/extents indicated by `mat` and sets its pointer.
// Ignores `padding` for NUQ tensors, which are always packed. // Ignores `padding` for NUQ tensors, which are always packed.
// Thread-compatible, weights are allocated in parallel. // Thread-compatible, weights are allocated in parallel.
void AllocateFor(MatPtr& mat, MatPadding padding); void AllocateFor(MatPtr& mat, const Allocator& allocator, MatPadding padding);
private: private:
AlignedPtr<uint8_t[]> storage_; AlignedPtr<uint8_t[]> storage_;
@ -455,13 +455,14 @@ class MatOwner {
template <typename MatT> template <typename MatT>
class MatStorageT : public MatPtrT<MatT> { class MatStorageT : public MatPtrT<MatT> {
public: public:
MatStorageT(const char* name, Extents2D extents, MatPadding padding) MatStorageT(const char* name, Extents2D extents, const Allocator& allocator,
MatPadding padding)
: MatPtrT<MatT>(name, extents) { : MatPtrT<MatT>(name, extents) {
if (extents.Area() != 0) owner_.AllocateFor(*this, padding); if (extents.Area() != 0) owner_.AllocateFor(*this, allocator, padding);
} }
// Shorthand for 1D tensors: packing does not help, hence `kPacked`. // Shorthand for 1D tensors: packing does not help, hence `kPacked`.
MatStorageT(const char* name, size_t cols) MatStorageT(const char* name, size_t cols, const Allocator& allocator)
: MatStorageT(name, Extents2D(1, cols), MatPadding::kPacked) {} : MatStorageT(name, Extents2D(1, cols), allocator, MatPadding::kPacked) {}
~MatStorageT() = default; ~MatStorageT() = default;
// Allow move for KVCache. // Allow move for KVCache.
@ -472,5 +473,31 @@ class MatStorageT : public MatPtrT<MatT> {
MatOwner owner_; MatOwner owner_;
}; };
// Helper for initializing members which are `MatStorageT<T>`: avoids having to
// specify Extents2D and MatPadding at each call site.
class MatFactory {
public:
// The constructor captures all the necessary arguments.
MatFactory(const char* name, size_t rows, size_t cols,
const Allocator& allocator, MatPadding padding = MatPadding::kOdd)
: name_(name),
extents_(rows, cols),
allocator_(allocator),
padding_(padding) {}
// Templated conversion so we do not have to specify the type in the
// member initializer.
template <typename T>
operator MatStorageT<T>() const {
return MatStorageT<T>(name_.c_str(), extents_, allocator_, padding_);
}
private:
const std::string name_;
Extents2D extents_;
const Allocator& allocator_;
MatPadding padding_;
};
} // namespace gcpp } // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_MAT_H_ #endif // THIRD_PARTY_GEMMA_CPP_UTIL_MAT_H_
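A minimal usage sketch for the allocator-taking MatStorageT constructors and the new MatFactory helper; everything except MatStorageT, MatFactory, MatPadding, Allocator and BF16 is a made-up name for illustration:

  struct ExampleActivations {
    ExampleActivations(const Allocator& allocator, size_t batch, size_t dim)
        // MatFactory bundles name/extents/allocator/padding; its templated
        // conversion operator supplies the element type from the member.
        : x(MatFactory("x", batch, dim, allocator)),
          x_bf(MatFactory("x_bf", batch, dim, allocator, MatPadding::kPacked)),
          // 1D shorthand: a single row, always kPacked.
          bias("bias", dim, allocator) {}

    MatStorageT<float> x;
    MatStorageT<BF16> x_bf;
    MatStorageT<float> bias;
  };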

View File

@ -15,60 +15,13 @@
#include "util/threading_context.h" #include "util/threading_context.h"
#include <memory>
#include <mutex> // NOLINT
#include "hwy/base.h" // HWY_ASSERT, HWY_UNLIKELY
#include "hwy/profiler.h"
namespace gcpp { namespace gcpp {
static ThreadingArgs s_args; ThreadingContext::ThreadingContext(const ThreadingArgs& args)
// Cannot use magic static because that does not support `Invalidate`, hence : topology(BoundedSlice(args.skip_packages, args.max_packages),
// allocate manually. BoundedSlice(args.skip_clusters, args.max_clusters),
static std::unique_ptr<ThreadingContext> s_ctx; BoundedSlice(args.skip_lps, args.max_lps)),
static std::mutex s_ctx_mutex; allocator(topology, args.bind != Tristate::kFalse),
pools(topology, allocator, args.max_threads, args.pin) {}
/*static*/ void ThreadingContext::SetArgs(const ThreadingArgs& args) {
s_ctx_mutex.lock();
HWY_ASSERT(!s_ctx); // Ensure not already initialized, else this is too late.
s_args = args;
s_ctx_mutex.unlock();
}
/*static*/ bool ThreadingContext::IsInitialized() {
s_ctx_mutex.lock();
const bool initialized = !!s_ctx;
s_ctx_mutex.unlock();
return initialized;
}
/*static*/ ThreadingContext& ThreadingContext::Get() {
PROFILER_FUNC;
// We do not bother with double-checked locking because it requires an
// atomic pointer, but we prefer to use unique_ptr for simplicity. Also,
// callers can cache the result and call less often.
s_ctx_mutex.lock();
if (HWY_UNLIKELY(!s_ctx)) {
s_ctx = std::make_unique<ThreadingContext>(PrivateToken());
}
s_ctx_mutex.unlock();
return *s_ctx;
}
/*static*/ void ThreadingContext::ThreadHostileInvalidate() {
// Deliberately avoid taking the lock so that tsan can warn if this is
// called concurrently with other calls to `Get`.
s_ctx.reset();
}
// WARNING: called with `s_ctx_mutex` held. Calling `SetArgs` or `Get` would
// deadlock.
ThreadingContext::ThreadingContext(ThreadingContext::PrivateToken)
: topology(BoundedSlice(s_args.skip_packages, s_args.max_packages),
BoundedSlice(s_args.skip_clusters, s_args.max_clusters),
BoundedSlice(s_args.skip_lps, s_args.max_lps)),
allocator(topology, s_args.bind != Tristate::kFalse),
pools(topology, allocator, s_args.max_threads, s_args.pin) {}
} // namespace gcpp } // namespace gcpp
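With the singleton machinery removed, callers construct the context explicitly and pass it (or its allocator and pools) down, as the test changes in this commit do. A minimal sketch; the override and extents are arbitrary example values:

  ThreadingArgs threading_args;
  threading_args.max_packages = 1;           // optional overrides, set before construction
  ThreadingContext ctx(threading_args);      // owns topology, allocator, pools

  hwy::ThreadPool& pool = ctx.pools.Pool();  // previously ThreadingContext::Get().pools
  MatMulEnv env(ctx);                        // pass the context by reference
  MatStorageT<float> tmp("tmp", Extents2D(4, 16), ctx.allocator,
                         MatPadding::kOdd);  // allocations take ctx.allocator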

View File

@ -85,43 +85,9 @@ class ThreadingArgs : public ArgsBase<ThreadingArgs> {
} }
}; };
// Lazily-initialized singleton with support for passing in arguments from struct ThreadingContext {
// `ThreadingArgs` and re-initializing with different arguments. // Expected to be called early in the program, before threading starts.
class ThreadingContext { explicit ThreadingContext(const ThreadingArgs& args);
struct PrivateToken {}; // avoids constructing directly
public:
// If not called, default arguments are used when `Get` initializes the
// singleton. Must not be called after `Get`, unless after a call to
// `ThreadHostileInvalidate`, because otherwise initialization already
// happened and the arguments would have no effect. Thread-safe, though this
// is expected to be called early in the program, before threading starts.
static void SetArgs(const ThreadingArgs& args);
// Returns whether `Get()` has already been called, typically used to avoid
// calling `SetArgs` after that, because it would assert.
static bool IsInitialized();
// Returns a reference to the singleton after initializing it if necessary.
// When initializing, uses the args passed to `SetArgs`, or defaults.
//
// It is safe to call this concurrently with other `Get`, but not with
// `SetArgs`, because that will warn if called after this, nor with
// `ThreadHostileInvalidate`, because that will invalidate the reference which
// callers of this may still be using. Such usage only occurs in tests,
// hence we prefer not to pull `std::shared_ptr` into the interface.
//
// To reduce overhead, callers should cache the result and call less often.
static ThreadingContext& Get();
// Invalidates the singleton before or after a call to `Get`. This allows
// changing the arguments between tests. Callers must again call `Get`
// afterwards to obtain an instance. WARNING: must not be called concurrently
// with other calls to `Get` and usages of its return value.
// Also useful to suppress memory leak warnings in tests.
static void ThreadHostileInvalidate();
explicit ThreadingContext(PrivateToken); // only called via `Get`.
BoundedTopology topology; BoundedTopology topology;
Allocator allocator; Allocator allocator;

View File

@ -280,8 +280,6 @@ std::vector<uint64_t> MeasureForkJoin(hwy::ThreadPool& pool) {
} }
const double t1 = hwy::platform::Now(); const double t1 = hwy::platform::Now();
// TODO(janwas): enable after Highway update
#if 0
if (pool.AutoTuneComplete()) { if (pool.AutoTuneComplete()) {
hwy::Span<hwy::CostDistribution> cd = pool.AutoTuneCosts(); hwy::Span<hwy::CostDistribution> cd = pool.AutoTuneCosts();
std::vector<double> costs; std::vector<double> costs;
@ -308,10 +306,6 @@ std::vector<uint64_t> MeasureForkJoin(hwy::ThreadPool& pool) {
} else { } else {
HWY_WARN("Auto-tuning did not complete yet."); HWY_WARN("Auto-tuning did not complete yet.");
} }
#else
(void)t0;
(void)t1;
#endif
char cpu100[100]; char cpu100[100];
static const bool have_stop = hwy::platform::HaveTimerStop(cpu100); static const bool have_stop = hwy::platform::HaveTimerStop(cpu100);
@ -383,7 +377,9 @@ TEST(ThreadingTest, BenchJoin) {
} }
}; };
NestedPools& pools = ThreadingContext::Get().pools; ThreadingArgs threading_args;
ThreadingContext ctx(threading_args);
NestedPools& pools = ctx.pools;
// Use last package because the main thread has been pinned to it. // Use last package because the main thread has been pinned to it.
const size_t pkg_idx = pools.NumPackages() - 1; const size_t pkg_idx = pools.NumPackages() - 1;