diff --git a/BUILD.bazel b/BUILD.bazel index f7fc5c8..4421189 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -389,6 +389,7 @@ cc_test( deps = [ ":mat", ":ops", + ":threading_context", "@googletest//:gtest_main", # buildcleaner: keep "//compression:compress", "@highway//:hwy", diff --git a/compression/python/compression_clif_aux.cc b/compression/python/compression_clif_aux.cc index 428e9cf..7b5bc15 100644 --- a/compression/python/compression_clif_aux.cc +++ b/compression/python/compression_clif_aux.cc @@ -57,6 +57,9 @@ class SbsWriterImpl : public ISbsWriter { template void InsertT(const char* name, F32Span weights, const TensorInfo& tensor_info) { + // TODO(janwas): 1D parallel-for. + hwy::ThreadPool& pool = ctx_.pools.Pool(); + MatPtrT mat(name, ExtentsFromInfo(&tensor_info)); // SFP and NUQ (which uses SFP for cluster centers) have a limited range // and depending on the input values may require rescaling. Scaling is @@ -73,13 +76,13 @@ class SbsWriterImpl : public ISbsWriter { mat.AppendTo(serialized_mat_ptrs_); mat_owners_.push_back(MatOwner()); - mat_owners_.back().AllocateFor(mat, MatPadding::kPacked); + mat_owners_.back().AllocateFor(mat, ctx_.allocator, MatPadding::kPacked); // Handle gemma_export_test's MockArray. Write blobs so that the test // succeeds, but we only have 10 floats, not the full tensor. if (weights.size() == 10 && mat.Extents().Area() != 10) { Compress(weights.data(), weights.size(), working_set_, mat.Span(), - /*packed_ofs=*/0, pool_); + /*packed_ofs=*/0, pool); writer_.Add(name, mat.Packed(), mat.ElementBytes() * 10); return; } @@ -89,12 +92,12 @@ class SbsWriterImpl : public ISbsWriter { TypeName(TypeEnum())); HWY_ASSERT(weights.size() == mat.Extents().Area()); Compress(weights.data(), weights.size(), working_set_, mat.Span(), - /*packed_ofs=*/0, pool_); + /*packed_ofs=*/0, pool); writer_.Add(name, mat.Packed(), mat.PackedBytes()); } public: - SbsWriterImpl() : pool_(ThreadingContext::Get().pools.Pool()) {} + SbsWriterImpl() : ctx_(ThreadingArgs()) {} void Insert(const char* name, F32Span weights, Type type, const TensorInfo& tensor_info) override { @@ -122,18 +125,18 @@ class SbsWriterImpl : public ISbsWriter { const GemmaTokenizer tokenizer( tokenizer_path.empty() ? kMockTokenizer : ReadFileToString(Path(tokenizer_path))); - WriteSingleFile(config, tokenizer, serialized_mat_ptrs_, writer_, pool_, - gcpp::Path(path)); + WriteSingleFile(config, tokenizer, serialized_mat_ptrs_, writer_, + ctx_.pools.Pool(), gcpp::Path(path)); } - hwy::ThreadPool& pool_; + ThreadingContext ctx_; std::vector mat_owners_; CompressWorkingSet working_set_; BlobWriter writer_; std::vector serialized_mat_ptrs_; }; -ISbsWriter* NewSbsWriter() { return new SbsWriterImpl; } +ISbsWriter* NewSbsWriter() { return new SbsWriterImpl(); } } // namespace HWY_NAMESPACE } // namespace gcpp diff --git a/compression/test_util-inl.h b/compression/test_util-inl.h index 3af8e78..7c4f854 100644 --- a/compression/test_util-inl.h +++ b/compression/test_util-inl.h @@ -69,12 +69,13 @@ void ForeachPackedAndRawType() { // Generates inputs: deterministic, within max SfpStream range. 
template -MatStorageT GenerateMat(const Extents2D& extents, MatPadding padding, +MatStorageT GenerateMat(const Extents2D& extents, + const Allocator& allocator, MatPadding padding, hwy::ThreadPool& pool) { gcpp::CompressWorkingSet ws; ws.tls.resize(pool.NumWorkers()); - MatStorageT raw("raw", extents, MatPadding::kPacked); - MatStorageT compressed("mat", extents, padding); + MatStorageT raw("raw", extents, allocator, MatPadding::kPacked); + MatStorageT compressed("mat", extents, allocator, padding); const float scale = SfpStream::kMax / extents.Area(); pool.Run(0, extents.rows, [&](const size_t r, size_t thread) { float* HWY_RESTRICT row = raw.Row(r); @@ -95,12 +96,13 @@ MatStorageT GenerateMat(const Extents2D& extents, MatPadding padding, // Same, but `extents` describes the transposed matrix. template MatStorageT GenerateTransposedMat(const Extents2D extents, + const Allocator& allocator, MatPadding padding, hwy::ThreadPool& pool) { gcpp::CompressWorkingSet ws; ws.tls.resize(pool.NumWorkers()); - MatStorageT raw("raw", extents, MatPadding::kPacked); - MatStorageT compressed("trans", extents, padding); + MatStorageT raw("raw", extents, allocator, MatPadding::kPacked); + MatStorageT compressed("trans", extents, allocator, padding); const float scale = SfpStream::kMax / extents.Area(); pool.Run(0, extents.rows, [&](const size_t r, size_t thread) { float* HWY_RESTRICT row = raw.Row(r); diff --git a/evals/benchmark.cc b/evals/benchmark.cc index 738070a..f7c614e 100644 --- a/evals/benchmark.cc +++ b/evals/benchmark.cc @@ -74,7 +74,8 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text, size_t num_tokens = std::min(prompt.size() - pos, batch_tokens); std::vector prompt_slice(prompt.begin() + pos, prompt.begin() + pos + num_tokens); - KVCache kv_cache(gemma.GetModelConfig(), gemma.Inference()); + KVCache kv_cache(gemma.GetModelConfig(), gemma.Inference(), + env.MutableEnv().ctx.allocator); float entropy = ComputeCrossEntropy(*env.GetGemma(), num_tokens, prompt_slice, kv_cache, env.MutableEnv(), env.Verbosity()); diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc index 98a4761..689062d 100644 --- a/evals/benchmark_helper.cc +++ b/evals/benchmark_helper.cc @@ -50,14 +50,13 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) { GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading, const InferenceArgs& inference) - : env_(MakeMatMulEnv(threading, inference)), - gemma_(loader, inference, env_.ctx.pools) { + : ctx_(threading), env_(ctx_), gemma_(loader, inference, ctx_) { const ModelConfig& config = gemma_.GetModelConfig(); // Only allocate one for starters because GenerateBatch might not be called. - kv_caches_.push_back(KVCache(config, inference)); + kv_caches_.push_back(KVCache(config, inference, ctx_.allocator)); if (inference.verbosity >= 2) { - ShowConfig(loader, threading, inference, config); + ShowConfig(loader, threading, inference, config, ctx_); } InitGenerator(inference, gen_); @@ -141,7 +140,8 @@ std::vector GemmaEnv::BatchQueryModel( // Ensure we have at least one KVCache per query. 
while (kv_caches_.size() < num_queries) { - kv_caches_.push_back(KVCache(gemma_.GetModelConfig(), gemma_.Inference())); + kv_caches_.push_back( + KVCache(gemma_.GetModelConfig(), gemma_.Inference(), ctx_.allocator)); } const hwy::Span kv_caches(&kv_caches_[0], num_queries); @@ -228,7 +228,8 @@ static constexpr const char* CompiledConfig() { } void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading, - const InferenceArgs& inference, const ModelConfig& config) { + const InferenceArgs& inference, const ModelConfig& config, + const ThreadingContext& ctx) { threading.Print(inference.verbosity); loader.Print(inference.verbosity); inference.Print(inference.verbosity); @@ -241,7 +242,6 @@ void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading, char* dt = ctime(&now); // NOLINT char cpu100[100] = "unknown"; (void)hwy::platform::GetCpuString(cpu100); - const ThreadingContext& ctx = ThreadingContext::Get(); fprintf(stderr, "Date & Time : %s" // dt includes \n diff --git a/evals/benchmark_helper.h b/evals/benchmark_helper.h index 176267e..8f1a238 100644 --- a/evals/benchmark_helper.h +++ b/evals/benchmark_helper.h @@ -49,9 +49,6 @@ class GemmaEnv { GemmaEnv(int argc, char** argv); GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading, const InferenceArgs& inference); - // Avoid memory leaks in test. - ~GemmaEnv() { ThreadingContext::ThreadHostileInvalidate(); } - MatMulEnv& Env() { return env_; } size_t MaxGeneratedTokens() const { @@ -115,6 +112,7 @@ class GemmaEnv { MatMulEnv& MutableEnv() { return env_; } private: + ThreadingContext ctx_; MatMulEnv env_; Gemma gemma_; std::mt19937 gen_; // Random number generator. @@ -126,7 +124,8 @@ class GemmaEnv { void LogSpeedStats(double time_start, size_t total_tokens); void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading, - const InferenceArgs& inference, const ModelConfig& config); + const InferenceArgs& inference, const ModelConfig& config, + const ThreadingContext& ctx); void ShowHelp(const LoaderArgs& loader, const ThreadingArgs& threading, const InferenceArgs& inference); diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc index e4cfcd5..8f1a7b3 100644 --- a/examples/hello_world/run.cc +++ b/examples/hello_world/run.cc @@ -51,9 +51,10 @@ int main(int argc, char** argv) { } // Instantiate model and KV Cache - gcpp::MatMulEnv env(MakeMatMulEnv(threading, inference)); - gcpp::Gemma gemma(loader, inference, env.ctx.pools); - gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference); + gcpp::ThreadingContext ctx(gcpp::UpdateArgs(threading, inference)); + gcpp::MatMulEnv env(ctx); + gcpp::Gemma gemma(loader, inference, ctx); + gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference, ctx.allocator); size_t generated = 0; // Initialize random number generator diff --git a/examples/simplified_gemma/gemma.hpp b/examples/simplified_gemma/gemma.hpp index b5eab41..48290e8 100644 --- a/examples/simplified_gemma/gemma.hpp +++ b/examples/simplified_gemma/gemma.hpp @@ -35,9 +35,10 @@ class SimplifiedGemma { SimplifiedGemma(const gcpp::LoaderArgs& loader, const gcpp::ThreadingArgs& threading = gcpp::ThreadingArgs(), const gcpp::InferenceArgs& inference = gcpp::InferenceArgs()) - : env_(MakeMatMulEnv(threading, inference)), - gemma_(loader, inference, env_.ctx.pools), - kv_cache_(gemma_.GetModelConfig(), inference) { + : ctx_(UpdateArgs(threading, inference)), + env_(ctx_), + gemma_(loader, inference, ctx_), + kv_cache_(gemma_.GetModelConfig(), inference, ctx_.allocator) { // Initialize 
random number generator std::random_device rd; gen_.seed(rd()); @@ -88,6 +89,7 @@ class SimplifiedGemma { ~SimplifiedGemma() = default; private: + gcpp::ThreadingContext ctx_; gcpp::MatMulEnv env_; gcpp::Gemma gemma_; gcpp::KVCache kv_cache_; diff --git a/gemma/activations.h b/gemma/activations.h index 45096cc..ccb1a59 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -35,13 +35,15 @@ namespace gcpp { struct GriffinActivations { GriffinActivations(const ModelConfig& config, size_t batch_size, - MatPadding pad) - : griffin_x("griffin_x", Extents2D(batch_size, config.model_dim), pad), - griffin_y("griffin_y", Extents2D(batch_size, config.model_dim), pad), - griffin_gate_x("griffin_gate_x", - Extents2D(batch_size, config.model_dim), pad), - griffin_multiplier("griffin_mul", - Extents2D(batch_size, config.model_dim), pad) {} + const Allocator& allocator) + : griffin_x( + MatFactory("griffin_x", batch_size, config.model_dim, allocator)), + griffin_y( + MatFactory("griffin_y", batch_size, config.model_dim, allocator)), + griffin_gate_x(MatFactory("griffin_gate_x", batch_size, + config.model_dim, allocator)), + griffin_multiplier(MatFactory("griffin_mul", batch_size, + config.model_dim, allocator)) {} void SetBatchSize(size_t batch_size) { if (griffin_x.Rows() == 0) return; @@ -70,34 +72,34 @@ struct AttentionActivations { AttentionActivations( const ModelConfig& config, const LayerConfig& layer_config, - size_t batch_size, size_t seq_len, MatPadding pad, + size_t batch_size, size_t seq_len, const Allocator& allocator, std::vector>& row_ptrs) : config(config), // `vocab_size == 0` means it is for Vit part, VitAttention is still MHA // and does not use an external KV cache. - q("q", - Extents2D(batch_size, - config.vocab_size == 0 - ? layer_config.heads * 3 * layer_config.qkv_dim - : layer_config.heads * layer_config.qkv_dim), - pad), + q(MatFactory("q", batch_size, + config.vocab_size == 0 + ? 
layer_config.heads * 3 * layer_config.qkv_dim + : layer_config.heads * layer_config.qkv_dim, + allocator)), - pre_att_rms_out("pre_att_rms_out", - Extents2D(batch_size, config.model_dim), pad), - att("att", Extents2D(batch_size, layer_config.heads * seq_len), pad), - att_out( - "att_out", - Extents2D(batch_size, layer_config.heads * layer_config.qkv_dim), - pad), - att_sums("att_sums", Extents2D(batch_size, config.model_dim), pad), + pre_att_rms_out(MatFactory("pre_att_rms_out", batch_size, + config.model_dim, allocator)), + att(MatFactory("att", batch_size, layer_config.heads * seq_len, + allocator)), + att_out(MatFactory("att_out", batch_size, + layer_config.heads * layer_config.qkv_dim, + allocator)), + att_sums( + MatFactory("att_sums", batch_size, config.model_dim, allocator)), inv_timescale( - CreateInvTimescale(layer_config.qkv_dim, + CreateInvTimescale(allocator, layer_config.qkv_dim, layer_config.post_qk == PostQKType::HalfRope)), inv_timescale_global(CreateInvTimescale( - layer_config.qkv_dim, layer_config.post_qk == PostQKType::HalfRope, - 1000000.0)), + allocator, layer_config.qkv_dim, + layer_config.post_qk == PostQKType::HalfRope, 1000000.0)), div_seq_len(static_cast(seq_len)), div_heads(static_cast(layer_config.heads)), @@ -149,21 +151,23 @@ struct AttentionActivations { struct Activations { Activations(const ModelConfig& config, size_t batch_size, size_t seq_len, + const Allocator& allocator, std::vector>& row_ptrs) : layer_config(config.layer_configs[0]), - x("x", Extents2D(batch_size, config.model_dim), pad_), - logits("logits", Extents2D(batch_size, config.vocab_size), pad_), + x(MatFactory("x", batch_size, config.model_dim, allocator)), + logits(MatFactory("logits", batch_size, config.vocab_size, allocator)), - pre_ffw_rms_out("pre_ffw_rms_out", - Extents2D(batch_size, config.model_dim), pad_), - C1("C1", Extents2D(batch_size, layer_config.ff_hidden_dim), pad_), - C2("C2", Extents2D(batch_size, layer_config.ff_hidden_dim), pad_), - ffw_out("ffw_out", Extents2D(batch_size, config.model_dim), pad_), + pre_ffw_rms_out(MatFactory("pre_ffw_rms_out", batch_size, + config.model_dim, allocator)), + C1(MatFactory("C1", batch_size, layer_config.ff_hidden_dim, allocator)), + C2(MatFactory("C2", batch_size, layer_config.ff_hidden_dim, allocator)), + ffw_out(MatFactory("ffw_out", batch_size, config.model_dim, allocator)), - attention(config, layer_config, batch_size, seq_len, pad_, row_ptrs), + attention(config, layer_config, batch_size, seq_len, allocator, + row_ptrs), griffin(config, config.model == Model::GRIFFIN_2B ? batch_size : 0, - pad_) { + allocator) { HWY_ASSERT(batch_size != 0); // For MatMul outputs, precompute their row pointers. @@ -193,15 +197,14 @@ struct Activations { } const LayerConfig& layer_config; - const Extents2D none_ = Extents2D(); - const MatPadding pad_ = MatPadding::kOdd; MatStorageT x; // input MatStorageT logits; // Gated FFW MatStorageT pre_ffw_rms_out; - MatStorageT C1; // TODO: BF16 after Activation() supports it + // Norm may be large, so prefer to keep as f32. 
+ MatStorageT C1; MatStorageT C2; MatStorageT ffw_out; diff --git a/gemma/bindings/context.cc b/gemma/bindings/context.cc index e540bb4..e8329c2 100644 --- a/gemma/bindings/context.cc +++ b/gemma/bindings/context.cc @@ -43,8 +43,10 @@ namespace gcpp { // ConversationData constructor implementation ConversationData::ConversationData(const ModelConfig& model_config, - const InferenceArgs& inference_args) - : kv_cache(std::make_unique(model_config, inference_args)), + const InferenceArgs& inference_args, + const Allocator& allocator) + : kv_cache( + std::make_unique(model_config, inference_args, allocator)), abs_pos(0) {} // ConversationData copy constructor implementation @@ -101,15 +103,16 @@ GemmaContext::GemmaContext(const LoaderArgs& loader, int max_generated_tokens) : inference_args(inference_args), threading_args(threading_args), - matmul_env(MakeMatMulEnv(threading_args, inference_args)), + ctx(UpdateArgs(threading_args, inference_args)), + matmul_env(ctx), active_conversation_name("default"), - model(loader, inference_args, matmul_env.ctx.pools) { + model(loader, inference_args, matmul_env.ctx) { std::stringstream ss; LogDebug("Creating initial ConversationData"); // Create the initial ConversationData object using make_shared active_conversation = std::make_shared( - model.GetModelConfig(), inference_args); + model.GetModelConfig(), inference_args, ctx.allocator); LogDebug( "Storing initial ConversationData in conversation_cache[\"default\"]"); @@ -188,7 +191,7 @@ int GemmaContext::GenerateInternal(const char* prompt_string, ? Extents2D(model_config.vit_config.seq_len / (pool_dim * pool_dim), model_config.model_dim) : Extents2D(0, 0), - MatPadding::kOdd); + ctx.allocator, MatPadding::kOdd); if (image_data != nullptr) { HWY_ASSERT(model_config.wrapping == PromptWrapping::PALIGEMMA || model_config.wrapping == PromptWrapping::GEMMA_VLM); diff --git a/gemma/bindings/context.h b/gemma/bindings/context.h index c954da3..fcf3529 100644 --- a/gemma/bindings/context.h +++ b/gemma/bindings/context.h @@ -41,7 +41,8 @@ namespace gcpp { // Struct to hold data for a single conversation thread struct ConversationData { ConversationData(const ModelConfig& model_config, - const InferenceArgs& inference_args); + const InferenceArgs& inference_args, + const Allocator& allocator); ConversationData(const ConversationData& other); std::unique_ptr kv_cache; @@ -178,8 +179,8 @@ class GemmaContext { // rewind to initial state. 
active_conversation->abs_pos = 0; // Replace the cache within the current ConversationData object - active_conversation->kv_cache = - std::make_unique(model.GetModelConfig(), inference_args); + active_conversation->kv_cache = std::make_unique( + model.GetModelConfig(), inference_args, ctx.allocator); LogDebug((log_prefix + "Successfully rewound to initial state.").c_str()); } else { @@ -197,7 +198,7 @@ class GemmaContext { LogDebug("Creating new conversation"); // Create a new ConversationData object using make_shared conversation_cache[name] = std::make_shared( - model.GetModelConfig(), inference_args); + model.GetModelConfig(), inference_args, ctx.allocator); return true; } @@ -280,6 +281,7 @@ class GemmaContext { // Cached args (remain global for the context) InferenceArgs inference_args; ThreadingArgs threading_args; + ThreadingContext ctx; MatMulEnv matmul_env; std::string active_conversation_name; diff --git a/gemma/gemma.cc b/gemma/gemma.cc index a5718d2..109ccad 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -537,7 +537,7 @@ void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end, const WeightsPtrs& weights, KVCache& kv_cache, MatMulEnv& env, TimingInfo& timing_info) { Activations activations(config, runtime_config.prefill_tbatch_size, - kv_cache.SeqLen(), env.row_ptrs); + kv_cache.SeqLen(), env.ctx.allocator, env.row_ptrs); AllQueries all_queries(prompt, pos, prefix_end, hwy::Span(&kv_cache, 1)); @@ -555,7 +555,8 @@ void GenerateBatchT(const ModelConfig& config, const size_t max_batch_size = HWY_MAX(runtime_config.decode_qbatch_size, runtime_config.prefill_tbatch_size); Activations activations(config, max_batch_size, - all_queries[0].kv_cache.SeqLen(), env.row_ptrs); + all_queries[0].kv_cache.SeqLen(), env.ctx.allocator, + env.row_ptrs); for (size_t start = 0; start < all_queries.NumQueries(); start += runtime_config.decode_qbatch_size) { @@ -579,7 +580,7 @@ void GenerateImageTokensT(const ModelConfig& config, prefill_runtime_config.prefill_tbatch_size = num_tokens / (vit_config.pool_dim * vit_config.pool_dim); Activations prefill_activations(vit_config, num_tokens, num_tokens, - env.row_ptrs); + env.ctx.allocator, env.row_ptrs); // Weights are for the full PaliGemma model, not just the ViT part. 
PrefillVit(config, weights, prefill_runtime_config, image, image_tokens, prefill_activations, env); @@ -596,28 +597,14 @@ HWY_EXPORT(GenerateSingleT); HWY_EXPORT(GenerateBatchT); HWY_EXPORT(GenerateImageTokensT); -MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args, - const InferenceArgs& inference_args) { - if (inference_args.decode_qbatch_size >= 256) { - ThreadingArgs copy = threading_args; - copy.max_packages = 1; - ThreadingContext::SetArgs(copy); - } else { - ThreadingContext::SetArgs(threading_args); - } - - return MatMulEnv(ThreadingContext::Get()); -} - Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference, - NestedPools& pools) + ThreadingContext& ctx) : reader_(loader.weights), model_(reader_, loader.tokenizer, loader.wrapping), weights_(model_.Config()), chat_template_(model_.Tokenizer(), model_.Config().model), inference_(inference) { - weights_.ReadFromBlobs(model_, reader_, loader, inference, mat_owners_, - pools.Pool()); + weights_.ReadFromBlobs(model_, reader_, loader, inference, mat_owners_, ctx); reader_.CloseFile(); } diff --git a/gemma/gemma.h b/gemma/gemma.h index dfcb2ee..43af21e 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -225,18 +225,15 @@ struct TimingInfo { size_t tokens_generated = 0; }; -// Returns the `MatMulEnv` after calling `SetArgs`. -MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args, - const InferenceArgs& inference_args); - // After construction, all methods are const and thread-compatible if using -// separate MatMulEnv for each thread. +// separate ThreadingContext for each thread. class Gemma { public: // Reads weights/config/tokenizer from the `BlobStore` at `loader.weights`. - // `pools` are used to parallelize loading. + // `ctx` is only used to read tensors, but it is typically also referenced + // by the `MatMulEnv` passed to the Generate* methods. 
Gemma(const LoaderArgs& loader, const InferenceArgs& inference, - NestedPools& pools); + ThreadingContext& ctx); ~Gemma(); // TODO: rename to Config() diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h index 9d543c4..df950a6 100644 --- a/gemma/gemma_args.h +++ b/gemma/gemma_args.h @@ -256,6 +256,16 @@ struct InferenceArgs : public ArgsBase { } }; +static inline ThreadingArgs UpdateArgs(const ThreadingArgs& threading_args, + const InferenceArgs& inference_args) { + if (inference_args.decode_qbatch_size >= 256) { + ThreadingArgs copy = threading_args; + copy.max_packages = 1; + return copy; + } + return threading_args; +} + } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_ diff --git a/gemma/kv_cache.cc b/gemma/kv_cache.cc index a2a577f..9d107e8 100644 --- a/gemma/kv_cache.cc +++ b/gemma/kv_cache.cc @@ -57,20 +57,24 @@ static size_t CappedSeqLen(const ModelConfig& config, } KVCache::KVCache(const Extents2D& conv1d_extents, - const Extents2D& rglru_extents, const Extents2D& kv_extents) - : conv1d_cache("conv1d_cache", conv1d_extents, MatPadding::kOdd), - rglru_cache("rglru_cache", rglru_extents, MatPadding::kOdd), - kv_cache("kv", kv_extents, MatPadding::kOdd) {} + const Extents2D& rglru_extents, const Extents2D& kv_extents, + const Allocator& allocator) + : conv1d_cache("conv1d_cache", conv1d_extents, allocator, MatPadding::kOdd), + rglru_cache("rglru_cache", rglru_extents, allocator, MatPadding::kOdd), + kv_cache("kv", kv_extents, allocator, MatPadding::kOdd), + allocator_(allocator) {} -KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args) - : KVCache(Extents2D(GriffinLayers(config), GriffinConv1dCols(config)), - Extents2D(GriffinLayers(config), config.model_dim), - Extents2D(CappedSeqLen(config, inference_args), - config.KVCacheCols())) {} +KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args, + const Allocator& allocator) + : KVCache( + Extents2D(GriffinLayers(config), GriffinConv1dCols(config)), + Extents2D(GriffinLayers(config), config.model_dim), + Extents2D(CappedSeqLen(config, inference_args), config.KVCacheCols()), + allocator) {} KVCache KVCache::Copy() { KVCache copy(conv1d_cache.Extents(), rglru_cache.Extents(), - kv_cache.Extents()); + kv_cache.Extents(), allocator_); if (conv1d_cache.Rows() != 0) { CopyMat(conv1d_cache, copy.conv1d_cache); diff --git a/gemma/kv_cache.h b/gemma/kv_cache.h index 3de9e7d..7b5b88d 100644 --- a/gemma/kv_cache.h +++ b/gemma/kv_cache.h @@ -28,7 +28,8 @@ namespace gcpp { using KV_t = float; struct KVCache { - KVCache(const ModelConfig& config, const InferenceArgs& inference_args); + KVCache(const ModelConfig& config, const InferenceArgs& inference_args, + const Allocator& allocator); // Returns a deep copy of the KVCache. Use explicit function instead of // copy ctor to make the cost explicit. @@ -47,9 +48,11 @@ struct KVCache { MatStorageT kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2] private: + const Allocator& allocator_; + // For use by other ctor and Copy() KVCache(const Extents2D& conv1d_extents, const Extents2D& rglru_extents, - const Extents2D& kv_extents); + const Extents2D& kv_extents, const Allocator& allocator); }; } // namespace gcpp diff --git a/gemma/run.cc b/gemma/run.cc index 5a2ba58..cd72d63 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -110,7 +110,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, have_image ? 
Extents2D(config.vit_config.seq_len / (pool_dim * pool_dim), config.model_dim) : Extents2D(0, 0), - MatPadding::kOdd); + env.ctx.allocator, MatPadding::kOdd); image_tokens.AllocateAndAttachRowPtrs(env.row_ptrs); if (have_image) { HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA || @@ -254,10 +254,11 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading, const InferenceArgs& inference) { PROFILER_ZONE("Run.misc"); - MatMulEnv env(MakeMatMulEnv(threading, inference)); + ThreadingContext ctx(UpdateArgs(threading, inference)); + MatMulEnv env(ctx); if (inference.verbosity >= 2) env.print_best = true; - const Gemma gemma(loader, inference, env.ctx.pools); - KVCache kv_cache(gemma.GetModelConfig(), inference); + const Gemma gemma(loader, inference, ctx); + KVCache kv_cache(gemma.GetModelConfig(), inference, ctx.allocator); if (inference.verbosity >= 1) { std::string instructions = @@ -284,7 +285,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading, if (inference.IsInteractive()) { std::cout << "\033[2J\033[1;1H" // clear screen << kAsciiArtBanner << "\n\n"; - ShowConfig(loader, threading, inference, gemma.GetModelConfig()); + ShowConfig(loader, threading, inference, gemma.GetModelConfig(), ctx); std::cout << "\n" << instructions << "\n"; } } diff --git a/gemma/vit.cc b/gemma/vit.cc index 8efed43..82838f2 100644 --- a/gemma/vit.cc +++ b/gemma/vit.cc @@ -80,11 +80,11 @@ class VitAttention { // Shift Q, K, VT to MatStorageT. MatStorageT Q("Q2", Extents2D(num_tokens_, qkv_dim), - MatPadding::kPacked); - MatStorageT K("K2", Extents2D(seq_len, qkv_dim), + env_.ctx.allocator, MatPadding::kPacked); + MatStorageT K("K2", Extents2D(seq_len, qkv_dim), env_.ctx.allocator, MatPadding::kPacked); MatStorageT C("C2", Extents2D(num_tokens_, seq_len), - MatPadding::kPacked); + env_.ctx.allocator, MatPadding::kPacked); // Initialize att_out to zero prior to head loop. ZeroInit(activations_.attention.att_out); @@ -294,7 +294,7 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image, // image_patches is (256, 14 * 14 * 3) // Must be padded, see `DoDecompressA`. MatStorageT image_patches("patches", Extents2D(num_tokens, patch_size), - MatPadding::kOdd); + env.ctx.allocator, MatPadding::kOdd); for (size_t i = 0; i < num_tokens; ++i) { image.GetPatch(i, image_patches.Row(i)); } @@ -329,7 +329,7 @@ void PrefillVit(const ModelConfig& model_config, const WeightsPtrs& weights, weights.vit_encoder_norm_bias, activations.x); if (model_config.wrapping == PromptWrapping::GEMMA_VLM) { - activations.x = AvgPool4x4(activations.x); + activations.x = AvgPool4x4(activations.x, env.ctx.allocator); // Apply soft embedding norm before input projection. CallUpcasted(&weights.mm_embed_norm, [&](const auto* weights_t) { diff --git a/gemma/weights.cc b/gemma/weights.cc index 3fdb2c1..b205bd1 100644 --- a/gemma/weights.cc +++ b/gemma/weights.cc @@ -44,7 +44,8 @@ namespace gcpp { // Copies att_weights from `attn_vec_einsum_w`. -void LayerWeightsPtrs::InitAttWeights(std::vector& mat_owners) { +void LayerWeightsPtrs::InitAttWeights(std::vector& mat_owners, + const Allocator& allocator) { // We only use this tensor for Gemma layers. 
if (layer_config.type != LayerAttentionType::kGemma) return; @@ -71,7 +72,7 @@ void LayerWeightsPtrs::InitAttWeights(std::vector& mat_owners) { static std::mutex m; std::lock_guard lock(m); mat_owners.push_back(MatOwner()); - mat_owners.back().AllocateFor(att_weights, MatPadding::kOdd); + mat_owners.back().AllocateFor(att_weights, allocator, MatPadding::kOdd); } const size_t T_bytes = att_weights.ElementBytes(); @@ -149,9 +150,10 @@ void LayerWeightsPtrs::SplitAttW1() { // Must be called after reading weights via `ForEachTensor`. // TODO: exporters should bake this into the weights already. // WARNING: called from multiple threads; `mat_owners` requires a lock. -void LayerWeightsPtrs::Fixup(std::vector& mat_owners) { +void LayerWeightsPtrs::Fixup(std::vector& mat_owners, + const Allocator& allocator) { // TODO(janwas): handle NUQ - InitAttWeights(mat_owners); + InitAttWeights(mat_owners, allocator); SplitW1(); SplitAttW1(); } @@ -223,13 +225,15 @@ void WeightsPtrs::CopyFrom(const WeightsPtrs& other) { // For reshaping file tensors to the shape expected by the code. This would // ideally already happen in the importer. Called by WeightsOwner::Fixup. void WeightsPtrs::Fixup(std::vector& mat_owners, - hwy::ThreadPool& pool) { + ThreadingContext& ctx) { + // TODO: use 1D parallel-for helper function + hwy::ThreadPool& pool = ctx.pools.Pool(); pool.Run(0, c_layers.size(), [&](uint64_t layer, size_t /*thread*/) { - GetLayer(layer)->Fixup(mat_owners); + GetLayer(layer)->Fixup(mat_owners, ctx.allocator); }); pool.Run(0, vit_layers.size(), [&](uint64_t layer, size_t /*thread*/) { - VitLayer(layer)->Fixup(mat_owners); + VitLayer(layer)->Fixup(mat_owners, ctx.allocator); }); } @@ -260,12 +264,12 @@ enum class Mode { // Decides whether to read or map based on heuristics and user override. static Mode ChooseMode(uint64_t file_bytes, const LoaderArgs& loader, - const InferenceArgs& inference) { + const InferenceArgs& inference, + const Allocator& allocator) { Tristate to_bf16 = loader.to_bf16; Tristate map = loader.map; // Disable mapping if not padded to the base page size. - const Allocator& allocator = ThreadingContext::Get().allocator; if (file_bytes % allocator.BasePageBytes() != 0) { if (map == Tristate::kTrue) { // Only complain if explicitly requested. HWY_WARN("Unable to map non-padded file (%zu, %zu), reading instead.", @@ -321,30 +325,31 @@ struct TensorToRead { // Allocates multiple in parallel and binds to NUMA nodes. static void AllocateAndBindAll(std::vector& tensors, const Mode mode, std::vector& owners, - hwy::ThreadPool& pool) { + ThreadingContext& ctx) { const size_t start = owners.size(); owners.resize(start + tensors.size()); - MMParallel parallel(ThreadingContext::Get()); + MMParallel parallel(ctx); // Allocate in parallel because faulting in large tensors is slow. - pool.Run(0, tensors.size(), [&](uint64_t task, size_t /*thread*/) { - TensorToRead& tensor = tensors[task]; - MatPtr& mat = *tensor.mat; + ctx.pools.Pool().Run( + 0, tensors.size(), [&](uint64_t task, size_t /*thread*/) { + TensorToRead& tensor = tensors[task]; + MatPtr& mat = *tensor.mat; - tensor.prev_type = mat.GetType(); - // We only care about MatMul inputs; skip F32 or small tensors. - if (tensor.prev_type == Type::kF32 || mat.Rows() < 1024) { - tensor.keep_type = true; - tensor.padding = MatPadding::kPacked; // single I/O for simplicity - } else if (mode == Mode::kReadBF16) { - mat.SetType(Type::kBF16); - } + tensor.prev_type = mat.GetType(); + // We only care about MatMul inputs; skip F32 or small tensors. 
+ if (tensor.prev_type == Type::kF32 || mat.Rows() < 1024) { + tensor.keep_type = true; + tensor.padding = MatPadding::kPacked; // single I/O for simplicity + } else if (mode == Mode::kReadBF16) { + mat.SetType(Type::kBF16); + } - owners[start + task].AllocateFor(*tensor.mat, tensor.padding); - // TODO(janwas): MatMul outputs will later also be BF16. - BindB(*tensor.mat, sizeof(float), parallel); - }); + owners[start + task].AllocateFor(*tensor.mat, ctx.allocator, + tensor.padding); + BindB(*tensor.mat, tensor.mat->ElementBytes(), parallel); + }); } // Mode == kMap @@ -482,7 +487,7 @@ static void ReadBatches(const BlobReader& reader, // Aborts on error. static void MapOrReadAll(std::vector& tensors, BlobReader& reader, Mode mode, std::vector& mat_owners, - hwy::ThreadPool& pool) { + ThreadingContext& ctx) { if (mode == Mode::kMap) { MapPtr mapped = reader.file().Map(); if (mapped) return MapAll(tensors, mapped); @@ -496,9 +501,11 @@ static void MapOrReadAll(std::vector& tensors, BlobReader& reader, { PROFILER_ZONE("Startup.Weights.Allocate"); // NOTE: this changes the stride of `mats`! - AllocateAndBindAll(tensors, mode, mat_owners, pool); + AllocateAndBindAll(tensors, mode, mat_owners, ctx); } + hwy::ThreadPool& pool = ctx.pools.Pool(); + if (mode == Mode::kReadBF16) return ReadAllToBF16(tensors, reader, pool); const std::vector batches = @@ -510,7 +517,7 @@ void WeightsPtrs::ReadFromBlobs(const ModelStore& model, BlobReader& reader, const LoaderArgs& loader, const InferenceArgs& inference, std::vector& mat_owners, - hwy::ThreadPool& pool) { + ThreadingContext& ctx) { // List of tensors to read/map, and where from. std::vector tensors; @@ -529,13 +536,14 @@ void WeightsPtrs::ReadFromBlobs(const ModelStore& model, BlobReader& reader, HWY_ABORT("Tensor %s is required but not found in file.", t.mat.Name()); }); - const Mode mode = ChooseMode(reader.file_bytes(), loader, inference); + const Mode mode = + ChooseMode(reader.file_bytes(), loader, inference, ctx.allocator); - MapOrReadAll(tensors, reader, mode, mat_owners, pool); + MapOrReadAll(tensors, reader, mode, mat_owners, ctx); { PROFILER_ZONE("Startup.Fixup"); - Fixup(mat_owners, pool); + Fixup(mat_owners, ctx); } } diff --git a/gemma/weights.h b/gemma/weights.h index 63ed70d..d9978ff 100644 --- a/gemma/weights.h +++ b/gemma/weights.h @@ -29,7 +29,7 @@ #include "gemma/tensor_info.h" // TensorInfoRegistry #include "io/blob_store.h" // BlobWriter #include "util/mat.h" // MatPtr -#include "hwy/contrib/thread_pool/thread_pool.h" +#include "util/threading_context.h" namespace gcpp { @@ -299,11 +299,12 @@ struct LayerWeightsPtrs { // Must be called after reading weights via `ForEachTensor`. // TODO: exporters should bake this into the weights already. // WARNING: called from multiple threads; `mat_owners` requires a lock. - void Fixup(std::vector& mat_owners); + void Fixup(std::vector& mat_owners, const Allocator& allocator); private: // Copies att_weights from `attn_vec_einsum_w`. - void InitAttWeights(std::vector& mat_owners); + void InitAttWeights(std::vector& mat_owners, + const Allocator& allocator); // For FFN. Fast, only updates pointers. void SplitW1(); @@ -426,7 +427,7 @@ struct WeightsPtrs { // override for whether to map blobs or read them. void ReadFromBlobs(const ModelStore& model, BlobReader& reader, const LoaderArgs& loader, const InferenceArgs& inference, - std::vector& mat_owners, hwy::ThreadPool& pool); + std::vector& mat_owners, ThreadingContext& ctx); // Adds one blob for each tensor's data and returns all serialized MatPtr. 
std::vector AddTensorDataToWriter(BlobWriter& writer) const; @@ -434,7 +435,7 @@ struct WeightsPtrs { private: // For reshaping file tensors to the shape expected by the code. This would // ideally already happen in the importer. Called by ReadFromBlobs. - void Fixup(std::vector& mat_owners, hwy::ThreadPool& pool); + void Fixup(std::vector& mat_owners, ThreadingContext& ctx); }; // `WeightsPtrs` #undef TENSOR_ARGS diff --git a/io/blob_compare.cc b/io/blob_compare.cc index 7430ef3..bb25843 100644 --- a/io/blob_compare.cc +++ b/io/blob_compare.cc @@ -227,11 +227,12 @@ void ReadAndCompareBlobs(const Path& path1, const Path& path2) { BlobVec blobs1 = ReserveMemory(ranges1, all_blobs, pos); BlobVec blobs2 = ReserveMemory(ranges2, all_blobs, pos); - NestedPools& pools = ThreadingContext::Get().pools; + ThreadingArgs args; + ThreadingContext ctx(args); ReadBothBlobs(reader1, reader2, ranges1, ranges2, total_bytes, blobs1, blobs2, - pools); + ctx.pools); - CompareBlobs(reader1.Keys(), blobs1, blobs2, total_bytes, pools); + CompareBlobs(reader1.Keys(), blobs1, blobs2, total_bytes, ctx.pools); } } // namespace gcpp diff --git a/io/blob_store_test.cc b/io/blob_store_test.cc index ec8f231..b763428 100644 --- a/io/blob_store_test.cc +++ b/io/blob_store_test.cc @@ -36,7 +36,9 @@ class BlobStoreTest : public testing::Test {}; #endif TEST(BlobStoreTest, TestReadWrite) { - hwy::ThreadPool& pool = ThreadingContext::Get().pools.Pool(); + ThreadingArgs threading_args; + ThreadingContext ctx(threading_args); + hwy::ThreadPool& pool = ctx.pools.Pool(); static const std::array kOriginalData = {-1, 0, 3.14159, 2.71828}; @@ -92,7 +94,9 @@ TEST(BlobStoreTest, TestReadWrite) { // Ensures padding works for any number of random-sized blobs. TEST(BlobStoreTest, TestNumBlobs) { - hwy::ThreadPool& pool = ThreadingContext::Get().pools.Pool(); + ThreadingArgs threading_args; + ThreadingContext ctx(threading_args); + hwy::ThreadPool& pool = ctx.pools.Pool(); hwy::RandomState rng; for (size_t num_blobs = 1; num_blobs <= 512; ++num_blobs) { diff --git a/ops/bench_matmul.cc b/ops/bench_matmul.cc index 7f0a587..04d535e 100644 --- a/ops/bench_matmul.cc +++ b/ops/bench_matmul.cc @@ -84,19 +84,22 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { const Extents2D B_extents(N, K); // already transposed const Extents2D C_extents(M, N); - MatStorageT C_slow("c_slow_batch", C_extents, MatPadding::kOdd); - MatStorageT C("c_batch", C_extents, MatPadding::kOdd); + MatStorageT C_slow("c_slow_batch", C_extents, env.ctx.allocator, + MatPadding::kOdd); + MatStorageT C("c_batch", C_extents, env.ctx.allocator, MatPadding::kOdd); - MatStorageT add_storage("add", Extents2D(), MatPadding::kPacked); + MatStorageT add_storage("add", Extents2D(), env.ctx.allocator, + MatPadding::kPacked); if (add) { - add_storage = - GenerateMat(Extents2D(1, N), MatPadding::kPacked, pool); + add_storage = GenerateMat(Extents2D(1, N), env.ctx.allocator, + MatPadding::kPacked, pool); add_storage.SetScale(1.0f); } - MatStorageT a = GenerateMat(A_extents, MatPadding::kOdd, pool); - MatStorageT b_trans = - GenerateTransposedMat(B_extents, MatPadding::kOdd, pool); + MatStorageT a = + GenerateMat(A_extents, env.ctx.allocator, MatPadding::kOdd, pool); + MatStorageT b_trans = GenerateTransposedMat( + B_extents, env.ctx.allocator, MatPadding::kOdd, pool); const float* add_row = add ? 
add_storage.PackedScale1() : nullptr; @@ -151,10 +154,10 @@ void BenchAllMatMul() { return; } - ThreadingContext& ctx = ThreadingContext::Get(); + ThreadingArgs threading_args; + ThreadingContext ctx(threading_args); fprintf(stderr, "BenchAllMatMul %s %s\n", ctx.topology.TopologyString(), ctx.pools.PinString()); - MatMulEnv env(ctx); for (size_t batch_size : {1, 4, 128, 512}) { diff --git a/ops/dot_test.cc b/ops/dot_test.cc index 4f0c94d..a461614 100644 --- a/ops/dot_test.cc +++ b/ops/dot_test.cc @@ -999,6 +999,8 @@ struct TestShortDotsT { const size_t N = hn::Lanes(d); const hn::ScalableTag df; // for CallDot + ThreadingArgs threading_args; + ThreadingContext ctx(threading_args); CompressWorkingSet work; std::mt19937 rng; rng.seed(12345); @@ -1009,14 +1011,14 @@ struct TestShortDotsT { // GenerateWellConditionedInputs calls DecompressAndZeroPad to `raw*`, // hence they require padding to one vector. const size_t padded_num = hwy::RoundUpTo(num, N); - MatStorageT raw_w("raw_w", padded_num); - MatStorageT raw_v("raw_v", padded_num); - MatStorageT weights("weights", padded_num); + MatStorageT raw_w("raw_w", padded_num, ctx.allocator); + MatStorageT raw_v("raw_v", padded_num, ctx.allocator); + MatStorageT weights("weights", padded_num, ctx.allocator); const PackedSpan w = weights.Span(); - MatStorageT vectors("vectors", padded_num); + MatStorageT vectors("vectors", padded_num, ctx.allocator); const PackedSpan v = vectors.Span(); - MatStorageT bufs("bufs", num); + MatStorageT bufs("bufs", padded_num, ctx.allocator); double* HWY_RESTRICT buf = bufs.Row(0); for (size_t rep = 0; rep < hn::AdjustedReps(20); ++rep) { @@ -1097,14 +1099,12 @@ void TestAllDot() { constexpr size_t kMaxWorkers = 15; - // Reset with cap on workers because we only support `kMaxWorkers`. - ThreadingContext::ThreadHostileInvalidate(); + // Limit workers because we only support `kMaxWorkers`. 
ThreadingArgs threading_args; threading_args.max_packages = 1; threading_args.max_clusters = 1; threading_args.max_lps = kMaxWorkers - 1; - ThreadingContext::SetArgs(threading_args); - ThreadingContext& ctx = ThreadingContext::Get(); + ThreadingContext ctx(threading_args); { // ensure no profiler zones are active const hn::ScalableTag df; @@ -1116,9 +1116,11 @@ void TestAllDot() { constexpr size_t kReps = hn::AdjustedReps(40); const size_t num = 24 * 1024; - MatStorageT a("a", Extents2D(kMaxWorkers, num), MatPadding::kOdd); - MatStorageT b("b", Extents2D(kMaxWorkers, num), MatPadding::kOdd); - MatStorageT bufs("bufs", Extents2D(kMaxWorkers, num), + MatStorageT a("a", Extents2D(kMaxWorkers, num), ctx.allocator, + MatPadding::kOdd); + MatStorageT b("b", Extents2D(kMaxWorkers, num), ctx.allocator, + MatPadding::kOdd); + MatStorageT bufs("bufs", Extents2D(kMaxWorkers, num), ctx.allocator, MatPadding::kOdd); std::array all_stats; diff --git a/ops/gemma_matvec_test.cc b/ops/gemma_matvec_test.cc index 0ca58b7..e55539d 100644 --- a/ops/gemma_matvec_test.cc +++ b/ops/gemma_matvec_test.cc @@ -26,6 +26,7 @@ #include #include "util/mat.h" +#include "util/threading_context.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" @@ -68,10 +69,11 @@ FloatPtr SimpleMatVecAdd(const MatStorageT& mat, const FloatPtr& vec, template std::unique_ptr> GenerateMat(size_t offset, + const Allocator& allocator, hwy::ThreadPool& pool) { gcpp::CompressWorkingSet ws; const Extents2D extents(kOuter, kInner); - auto mat = std::make_unique>("TestMat", extents, + auto mat = std::make_unique>("TestMat", extents, allocator, MatPadding::kPacked); FloatPtr raw_mat = hwy::AllocateAligned(extents.Area()); HWY_ASSERT(raw_mat); @@ -109,10 +111,12 @@ void AssertClose(const FloatPtr& a, const FloatPtr& b) { } void TestMatVecAdd() { - hwy::ThreadPool pool(hwy::ThreadPool::MaxThreads()); + ThreadingArgs threading_args; + ThreadingContext ctx(threading_args); + hwy::ThreadPool& pool = ctx.pools.Pool(); constexpr size_t kOuter = 128 * 3; constexpr size_t kInner = 128 * 5; - auto mat = GenerateMat(0, pool); + auto mat = GenerateMat(0, ctx.allocator, pool); FloatPtr vec = GenerateVec(0); FloatPtr add = GenerateVec(0); FloatPtr expected_out = SimpleMatVecAdd(*mat, vec, add); @@ -124,11 +128,13 @@ void TestMatVecAdd() { } void TestTwoMatVecAdd() { - hwy::ThreadPool pool(hwy::ThreadPool::MaxThreads()); + ThreadingArgs threading_args; + ThreadingContext ctx(threading_args); + hwy::ThreadPool& pool = ctx.pools.Pool(); constexpr size_t kOuter = 128 * 3; constexpr size_t kInner = 128 * 5; - auto mat0 = GenerateMat(0, pool); - auto mat1 = GenerateMat(1, pool); + auto mat0 = GenerateMat(0, ctx.allocator, pool); + auto mat1 = GenerateMat(1, ctx.allocator, pool); FloatPtr vec = GenerateVec(0); FloatPtr add0 = GenerateVec(0); FloatPtr add1 = GenerateVec(1); @@ -145,10 +151,13 @@ void TestTwoMatVecAdd() { } void TestTwoOfsMatVecAddLoop() { - hwy::ThreadPool pool(hwy::ThreadPool::MaxThreads()); + ThreadingArgs threading_args; + ThreadingContext ctx(threading_args); + hwy::ThreadPool& pool = ctx.pools.Pool(); + constexpr size_t kOuter = 128 * 3; constexpr size_t kInner = 128 * 5; - auto mat = GenerateMat(0, pool); + auto mat = GenerateMat(0, ctx.allocator, pool); FloatPtr vec = GenerateVec(0); FloatPtr add0 = GenerateVec(0); FloatPtr add1 = GenerateVec(1); diff --git a/ops/matmul.h b/ops/matmul.h index fed3bc6..99477d3 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -53,8 +53,8 @@ constexpr size_t kNR 
= 4; // or less on ISAs with fewer registers, or for the last few rows of A. static constexpr size_t kMaxMR = 4; -// Mostly stateless, can be constructed on the fly by weights.cc, but captures -// the singleton ThreadingContext to reduce MatMul call overhead. +// Mostly stateless, can be constructed on the fly by weights.cc. Captures +// the ThreadingContext to shorten call sites. class MMParallel { public: static constexpr size_t kMaxPackages = 4; @@ -251,7 +251,7 @@ class MMStorage { : // Per-worker copies of `partial` would be wasteful. We instead // allocate one instance of the maximum matrix extents because threads // write at false-sharing-free granularity. - partial_storage_("partial_storage", Extents2D(kMaxM, kMaxN), + partial_storage_("partial_storage", Extents2D(kMaxM, kMaxN), allocator, MatPadding::kOdd), // Same stride independent of the actual C.Cols() so we can pre-bind. partial_(partial_storage_.Row(0), kMaxN, partial_storage_.Stride()) { @@ -259,7 +259,7 @@ class MMStorage { // Must be padded, see `DoDecompressA`. parallel.ForPkg(MMParallel::kMaxPackages, [&](size_t pkg_idx) { pkg_A_[pkg_idx].reset(new MatStorageT( - "pkg_A", Extents2D(kMaxM, kMaxK), MatPadding::kOdd)); + "pkg_A", Extents2D(kMaxM, kMaxK), allocator, MatPadding::kOdd)); if (allocator.ShouldBind()) { const size_t node = parallel.Node(pkg_idx); diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc index 3885534..2ec77f5 100644 --- a/ops/matmul_test.cc +++ b/ops/matmul_test.cc @@ -91,14 +91,15 @@ void AssertClose(const MatPtrT& A, const MatPtrT& B, const size_t cols = A.Cols(); const size_t B_rows = B.Rows(); // Round up for DecompressAndZeroPad. - MatStorageT a_batch("a_batch", A.Extents(), MatPadding::kOdd); - MatStorageT b_trans_batch("b_trans_batch", B.Extents(), - MatPadding::kOdd); - MatStorageT c_batch("c_batch", Extents2D(A.Rows(), B_rows), + MatStorageT a_batch("a_batch", A.Extents(), env.ctx.allocator, MatPadding::kOdd); + MatStorageT b_trans_batch("b_trans_batch", B.Extents(), + env.ctx.allocator, MatPadding::kOdd); + MatStorageT c_batch("c_batch", Extents2D(A.Rows(), B_rows), + env.ctx.allocator, MatPadding::kOdd); c_batch.AllocateAndAttachRowPtrs(env.row_ptrs); MatStorageT c_slow_batch("c_slow_batch", Extents2D(A.Rows(), B_rows), - MatPadding::kOdd); + env.ctx.allocator, MatPadding::kOdd); for (size_t m = 0; m < A.Rows(); ++m) { DecompressAndZeroPad(df, MakeSpan(A.Row(m), cols), 0, a_batch.Row(m), cols); DecompressAndZeroPad(df, MakeSpan(C.Row(m), B_rows), 0, c_batch.Row(m), @@ -219,17 +220,21 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add, const Extents2D B_extents(cols_bc, cols_a_rows_b); // already transposed const Extents2D C_extents(rows_ac, cols_bc); - MatStorageT A(GenerateMat(A_extents, MatPadding::kOdd, pool)); + MatStorageT A( + GenerateMat(A_extents, env.ctx.allocator, MatPadding::kOdd, pool)); // Must be packed because we call Span() on it. - MatStorageT BT( - GenerateTransposedMat(B_extents, MatPadding::kPacked, pool)); - MatStorageT C_slow("C_slow", C_extents, MatPadding::kOdd); - MatStorageT C("C", C_extents, MatPadding::kOdd); + MatStorageT BT(GenerateTransposedMat(B_extents, env.ctx.allocator, + MatPadding::kPacked, pool)); + MatStorageT C_slow("C_slow", C_extents, env.ctx.allocator, + MatPadding::kOdd); + MatStorageT C("C", C_extents, env.ctx.allocator, MatPadding::kOdd); C.AllocateAndAttachRowPtrs(env.row_ptrs); MatStorageT add_storage = - add ?
GenerateMat(Extents2D(1, cols_bc), MatPadding::kPacked, pool) - : MatStorageT("add", Extents2D(), MatPadding::kPacked); + add ? GenerateMat(Extents2D(1, cols_bc), env.ctx.allocator, + MatPadding::kPacked, pool) + : MatStorageT("add", Extents2D(), env.ctx.allocator, + MatPadding::kPacked); add_storage.SetScale(1.0f); const float* add_row = add ? add_storage.PackedScale1() : nullptr; @@ -252,12 +257,11 @@ void TestTiny() { if (HWY_TARGET != first_target) return; for (size_t max_packages : {1, 2}) { - ThreadingContext::ThreadHostileInvalidate(); ThreadingArgs threading_args; threading_args.bind = Tristate::kTrue; threading_args.max_packages = max_packages; - ThreadingContext::SetArgs(threading_args); - MatMulEnv env(ThreadingContext::Get()); + ThreadingContext ctx(threading_args); + MatMulEnv env(ctx); NestedPools& pools = env.ctx.pools; if constexpr (GEMMA_DISABLE_TOPOLOGY) { @@ -291,11 +295,10 @@ void TestAllMatMul() { return; } - ThreadingContext::ThreadHostileInvalidate(); ThreadingArgs threading_args; threading_args.bind = Tristate::kTrue; - ThreadingContext::SetArgs(threading_args); - MatMulEnv env(ThreadingContext::Get()); + ThreadingContext ctx(threading_args); + MatMulEnv env(ctx); NestedPools& pools = env.ctx.pools; pools.MaybeStartSpinning(threading_args.spin); diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 3a8132e..0c53805 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -1018,13 +1018,13 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK( // Input has 4096 (64*64) rows, output has 256 (16*16) rows // Each output row is the average of a 4x4 block of input rows template -MatStorageT AvgPool4x4(MatStorageT& input) { +MatStorageT AvgPool4x4(MatStorageT& input, const Allocator& allocator) { const Extents2D extents = input.Extents(); // Input validation HWY_DASSERT(extents.rows == 4096); // 64 * 64 = 4096 input rows // Create output with 256 rows and same number of columns const size_t out_rows = 256; // 16 * 16 = 256 output rows - MatStorageT result("pool4x4", Extents2D(out_rows, extents.cols), + MatStorageT result("pool4x4", Extents2D(out_rows, extents.cols), allocator, MatPadding::kOdd); const size_t input_dim = 64; // Input is 64×64 const size_t output_dim = 16; // Output is 16×16 diff --git a/ops/ops.h b/ops/ops.h index 73d9327..03b023b 100644 --- a/ops/ops.h +++ b/ops/ops.h @@ -26,9 +26,10 @@ namespace gcpp { static inline HWY_MAYBE_UNUSED MatStorageT CreateInvTimescale( - size_t qkv_dim, bool half_rope, double base_frequency = 10000.0) { + const Allocator& allocator, size_t qkv_dim, bool half_rope, + double base_frequency = 10000.0) { const size_t rope_dim = half_rope ? 
qkv_dim / 2 : qkv_dim; - MatStorageT inv_timescale("inv_timescale", rope_dim / 2); + MatStorageT inv_timescale("inv_timescale", rope_dim / 2, allocator); for (size_t dim = 0; dim < rope_dim / 2; ++dim) { const double freq_exponents = static_cast(2 * dim) / static_cast(rope_dim); diff --git a/ops/ops_test.cc b/ops/ops_test.cc index d2cf821..da1e23c 100644 --- a/ops/ops_test.cc +++ b/ops/ops_test.cc @@ -347,10 +347,12 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy( } void TestRopeAndMulBy() { + ThreadingArgs threading_args; + ThreadingContext ctx(threading_args); const ModelConfig config(Model::GEMMA2_9B, Type::kSFP, ChooseWrapping(Model::GEMMA2_9B)); const size_t dim_qkv = config.layer_configs[0].qkv_dim; - MatStorageT x("x", dim_qkv); + MatStorageT x("x", dim_qkv, ctx.allocator); std::mt19937 gen; gen.seed(0x12345678); @@ -364,13 +366,13 @@ void TestRopeAndMulBy() { const float qmul = AttentionActivations::ChooseQueryScale(config); constexpr float kmul = 1.0f; - MatStorageT qexpected("qexpected", dim_qkv); - MatStorageT qactual("qactual", dim_qkv); - MatStorageT kexpected("kexpected", dim_qkv); - MatStorageT kactual("kactual", dim_qkv); - MatStorageT kactual2("kactual2", dim_qkv); + MatStorageT qexpected("qexpected", dim_qkv, ctx.allocator); + MatStorageT qactual("qactual", dim_qkv, ctx.allocator); + MatStorageT kexpected("kexpected", dim_qkv, ctx.allocator); + MatStorageT kactual("kactual", dim_qkv, ctx.allocator); + MatStorageT kactual2("kactual2", dim_qkv, ctx.allocator); MatStorageT inv_timescale = CreateInvTimescale( - config.layer_configs[0].qkv_dim, + ctx.allocator, config.layer_configs[0].qkv_dim, config.layer_configs[0].post_qk == PostQKType::HalfRope); // Assert VectorizedRope computation is same as regular rope at different pos. for (size_t pos = 1; pos < 500; pos++) { diff --git a/paligemma/paligemma_helper.cc b/paligemma/paligemma_helper.cc index 6f811e3..4872553 100644 --- a/paligemma/paligemma_helper.cc +++ b/paligemma/paligemma_helper.cc @@ -21,7 +21,7 @@ void PaliGemmaHelper::InitVit(const std::string& path) { image_tokens_ = std::make_unique( "image", Extents2D(config.vit_config.seq_len, config.model_dim), - MatPadding::kPacked); + env_->Env().ctx.allocator, MatPadding::kPacked); image_tokens_->AllocateAndAttachRowPtrs(env_->Env().row_ptrs); Image image; HWY_ASSERT(image.ReadPPM(path)); diff --git a/python/gemma_py.cc b/python/gemma_py.cc index 3496858..238b546 100644 --- a/python/gemma_py.cc +++ b/python/gemma_py.cc @@ -186,7 +186,7 @@ class GemmaModel { image_tokens_.reset(new gcpp::ImageTokens( "image_tokens", gcpp::Extents2D(config.vit_config.seq_len, config.model_dim), - gcpp::MatPadding::kOdd)); + env_.MutableEnv().ctx.allocator, gcpp::MatPadding::kOdd)); gcpp::RuntimeConfig runtime_config = {.gen = &env_.MutableGen(), .verbosity = 0}; gemma.GenerateImageTokens(runtime_config, env_.MutableKVCache().SeqLen(), diff --git a/util/mat.cc b/util/mat.cc index 44d62ec..f81767d 100644 --- a/util/mat.cc +++ b/util/mat.cc @@ -78,10 +78,10 @@ size_t Stride(MatPadding padding, size_t cols, size_t element_bytes, } } -void MatOwner::AllocateFor(MatPtr& mat, MatPadding padding) { +void MatOwner::AllocateFor(MatPtr& mat, const Allocator& allocator, + MatPadding padding) { const bool is_nuq = mat.GetType() == Type::kNUQ; if (is_nuq) padding = MatPadding::kPacked; - const Allocator& allocator = ThreadingContext::Get().allocator; const size_t stride = Stride(padding, mat.Cols(), mat.ElementBytes(), allocator.LineBytes()); const size_t num = is_nuq ? 
mat.PackedBytes() : mat.Rows() * stride; diff --git a/util/mat.h b/util/mat.h index 786351d..b0de72d 100644 --- a/util/mat.h +++ b/util/mat.h @@ -443,7 +443,7 @@ class MatOwner { // Allocates the type/extents indicated by `mat` and sets its pointer. // Ignores `padding` for NUQ tensors, which are always packed. // Thread-compatible, weights are allocated in parallel. - void AllocateFor(MatPtr& mat, MatPadding padding); + void AllocateFor(MatPtr& mat, const Allocator& allocator, MatPadding padding); private: AlignedPtr storage_; @@ -455,13 +455,14 @@ class MatOwner { template class MatStorageT : public MatPtrT { public: - MatStorageT(const char* name, Extents2D extents, MatPadding padding) + MatStorageT(const char* name, Extents2D extents, const Allocator& allocator, + MatPadding padding) : MatPtrT(name, extents) { - if (extents.Area() != 0) owner_.AllocateFor(*this, padding); + if (extents.Area() != 0) owner_.AllocateFor(*this, allocator, padding); } // Shorthand for 1D tensors: packing does not help, hence `kPacked`. - MatStorageT(const char* name, size_t cols) - : MatStorageT(name, Extents2D(1, cols), MatPadding::kPacked) {} + MatStorageT(const char* name, size_t cols, const Allocator& allocator) + : MatStorageT(name, Extents2D(1, cols), allocator, MatPadding::kPacked) {} ~MatStorageT() = default; // Allow move for KVCache. @@ -472,5 +473,31 @@ class MatStorageT : public MatPtrT { MatOwner owner_; }; +// Helper for initializing members which are `MatStorageT`: avoids having to +// specify Extents2D and MatPadding at each call site. +class MatFactory { + public: + // The constructor captures all the necessary arguments. + MatFactory(const char* name, size_t rows, size_t cols, + const Allocator& allocator, MatPadding padding = MatPadding::kOdd) + : name_(name), + extents_(rows, cols), + allocator_(allocator), + padding_(padding) {} + + // Templated conversion so we do not have to specify the type in the + // member initializer. + template + operator MatStorageT() const { + return MatStorageT(name_.c_str(), extents_, allocator_, padding_); + } + + private: + const std::string name_; + Extents2D extents_; + const Allocator& allocator_; + MatPadding padding_; +}; + } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_UTIL_MAT_H_ diff --git a/util/threading_context.cc b/util/threading_context.cc index 0cfdcb5..14eca95 100644 --- a/util/threading_context.cc +++ b/util/threading_context.cc @@ -15,60 +15,13 @@ #include "util/threading_context.h" -#include -#include // NOLINT - -#include "hwy/base.h" // HWY_ASSERT, HWY_UNLIKELY -#include "hwy/profiler.h" - namespace gcpp { -static ThreadingArgs s_args; -// Cannot use magic static because that does not support `Invalidate`, hence -// allocate manually. -static std::unique_ptr s_ctx; -static std::mutex s_ctx_mutex; - -/*static*/ void ThreadingContext::SetArgs(const ThreadingArgs& args) { - s_ctx_mutex.lock(); - HWY_ASSERT(!s_ctx); // Ensure not already initialized, else this is too late. - s_args = args; - s_ctx_mutex.unlock(); -} - -/*static*/ bool ThreadingContext::IsInitialized() { - s_ctx_mutex.lock(); - const bool initialized = !!s_ctx; - s_ctx_mutex.unlock(); - return initialized; -} - -/*static*/ ThreadingContext& ThreadingContext::Get() { - PROFILER_FUNC; - // We do not bother with double-checked locking because it requires an - // atomic pointer, but we prefer to use unique_ptr for simplicity. Also, - // callers can cache the result and call less often. 
- s_ctx_mutex.lock(); - if (HWY_UNLIKELY(!s_ctx)) { - s_ctx = std::make_unique(PrivateToken()); - } - s_ctx_mutex.unlock(); - return *s_ctx; -} - -/*static*/ void ThreadingContext::ThreadHostileInvalidate() { - // Deliberately avoid taking the lock so that tsan can warn if this is - // called concurrently with other calls to `Get`. - s_ctx.reset(); -} - -// WARNING: called with `s_ctx_mutex` held. Calling `SetArgs` or `Get` would -// deadlock. -ThreadingContext::ThreadingContext(ThreadingContext::PrivateToken) - : topology(BoundedSlice(s_args.skip_packages, s_args.max_packages), - BoundedSlice(s_args.skip_clusters, s_args.max_clusters), - BoundedSlice(s_args.skip_lps, s_args.max_lps)), - allocator(topology, s_args.bind != Tristate::kFalse), - pools(topology, allocator, s_args.max_threads, s_args.pin) {} +ThreadingContext::ThreadingContext(const ThreadingArgs& args) + : topology(BoundedSlice(args.skip_packages, args.max_packages), + BoundedSlice(args.skip_clusters, args.max_clusters), + BoundedSlice(args.skip_lps, args.max_lps)), + allocator(topology, args.bind != Tristate::kFalse), + pools(topology, allocator, args.max_threads, args.pin) {} } // namespace gcpp diff --git a/util/threading_context.h b/util/threading_context.h index e35d368..564ea90 100644 --- a/util/threading_context.h +++ b/util/threading_context.h @@ -85,43 +85,9 @@ class ThreadingArgs : public ArgsBase { } }; -// Lazily-initialized singleton with support for passing in arguments from -// `ThreadingArgs` and re-initializing with different arguments. -class ThreadingContext { - struct PrivateToken {}; // avoids constructing directly - - public: - // If not called, default arguments are used when `Get` initializes the - // singleton. Must not be called after `Get`, unless after a call to - // `ThreadHostileInvalidate`, because otherwise initialization already - // happened and the arguments would have no effect. Thread-safe, though this - // is expected to be called early in the program, before threading starts. - static void SetArgs(const ThreadingArgs& args); - - // Returns whether `Get()` has already been called, typically used to avoid - // calling `SetArgs` after that, because it would assert. - static bool IsInitialized(); - - // Returns a reference to the singleton after initializing it if necessary. - // When initializing, uses the args passed to `SetArgs`, or defaults. - // - // It is safe to call this concurrently with other `Get`, but not with - // `SetArgs`, because that will warn if called after this, nor with - // `ThreadHostileInvalidate`, because that will invalidate the reference which - // callers of this may still be using. Such usage only occurs in tests, - // hence we prefer not to pull `std::shared_ptr` into the interface. - // - // To reduce overhead, callers should cache the result and call less often. - static ThreadingContext& Get(); - - // Invalidates the singleton before or after a call to `Get`. This allows - // changing the arguments between tests. Callers must again call `Get` - // afterwards to obtain an instance. WARNING: must not be called concurrently - // with other calls to `Get` and usages of its return value. - // Also useful to suppress memory leak warnings in tests. - static void ThreadHostileInvalidate(); - - explicit ThreadingContext(PrivateToken); // only called via `Get`. +struct ThreadingContext { + // Expected to be called early in the program, before threading starts. 
+ explicit ThreadingContext(const ThreadingArgs& args); BoundedTopology topology; Allocator allocator; diff --git a/util/threading_test.cc b/util/threading_test.cc index b6626ac..ac2746b 100644 --- a/util/threading_test.cc +++ b/util/threading_test.cc @@ -280,8 +280,6 @@ std::vector MeasureForkJoin(hwy::ThreadPool& pool) { } const double t1 = hwy::platform::Now(); -// TODO(janwas): enable after Highway update -#if 0 if (pool.AutoTuneComplete()) { hwy::Span cd = pool.AutoTuneCosts(); std::vector costs; @@ -308,10 +306,6 @@ std::vector MeasureForkJoin(hwy::ThreadPool& pool) { } else { HWY_WARN("Auto-tuning did not complete yet."); } -#else - (void)t0; - (void)t1; -#endif char cpu100[100]; static const bool have_stop = hwy::platform::HaveTimerStop(cpu100); @@ -383,7 +377,9 @@ TEST(ThreadingTest, BenchJoin) { } }; - NestedPools& pools = ThreadingContext::Get().pools; + ThreadingArgs threading_args; + ThreadingContext ctx(threading_args); + NestedPools& pools = ctx.pools; // Use last package because the main thread has been pinned to it. const size_t pkg_idx = pools.NumPackages() - 1;
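For reference, below is a minimal sketch (not part of the patch) of the call pattern this change establishes, mirroring examples/hello_world/run.cc and gemma/run.cc above: callers construct a ThreadingContext from ThreadingArgs and pass it, or its allocator, explicitly instead of fetching the removed ThreadingContext::Get() singleton. Header paths follow the files touched above; the function name, tensor names, and extents are illustrative assumptions.

#include "gemma/gemma.h"             // Gemma
#include "gemma/gemma_args.h"        // LoaderArgs, InferenceArgs, UpdateArgs
#include "gemma/kv_cache.h"          // KVCache
#include "ops/matmul.h"              // MatMulEnv
#include "util/mat.h"                // MatStorageT, MatFactory
#include "util/threading_context.h"  // ThreadingArgs, ThreadingContext

void RunWithExplicitContext(const gcpp::LoaderArgs& loader,
                            const gcpp::ThreadingArgs& threading,
                            const gcpp::InferenceArgs& inference) {
  // The context is constructed explicitly (here with the qbatch-dependent
  // package cap applied by UpdateArgs) rather than fetched from a singleton.
  gcpp::ThreadingContext ctx(gcpp::UpdateArgs(threading, inference));
  gcpp::MatMulEnv env(ctx);

  // Model loading and the KV cache take the context / allocator explicitly.
  const gcpp::Gemma gemma(loader, inference, ctx);
  gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference, ctx.allocator);

  // Tensor storage also receives the allocator, either directly ...
  gcpp::MatStorageT<float> scratch("scratch", gcpp::Extents2D(4, 64),
                                   ctx.allocator, gcpp::MatPadding::kOdd);
  // ... or via MatFactory, which defaults to MatPadding::kOdd and avoids
  // repeating Extents2D/MatPadding in member initializer lists.
  gcpp::MatStorageT<float> x = gcpp::MatFactory("x", 4, 64, ctx.allocator);

  (void)env;
  (void)kv_cache;
  (void)scratch;
  (void)x;
}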