From 868b01601fdd7c758086e02756965bbff9ae65c8 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Mon, 4 Nov 2024 07:47:49 -0800 Subject: [PATCH] Simpler MatMul interface, vocab types, Tristate for use_spinning Add Extents2D, Range2D vocab types Matmul uses ConstMat for inputs and RowPtr for output Move RowVectorBatch to basics.h Separate threading.cc Fix topology string: report cores not LPs, and #HT Move QStride/IsMHA into LayerConfig ImageTokens does not require make_unique. matmul_test: no longer require template args PiperOrigin-RevId: 692963605 --- BUILD.bazel | 11 +- CMakeLists.txt | 1 + backprop/optimize_test.cc | 4 +- compression/compress.h | 25 ++- compression/shared.h | 8 +- evals/benchmark_helper.cc | 4 +- gemma/activations.h | 45 ++-- gemma/configs.h | 7 + gemma/gemma-inl.h | 351 +++++++++++++------------------ gemma/gemma.cc | 13 +- gemma/gemma.h | 6 +- gemma/run.cc | 44 ++-- gemma/weights.h | 19 +- ops/dot_test.cc | 20 +- ops/matmul-inl.h | 309 +++++++++++++++------------- ops/matmul.h | 55 +---- ops/matmul_test.cc | 306 ++++++++++++++------------- ops/ops_test.cc | 2 +- paligemma/paligemma_test.cc | 13 +- util/allocator.cc | 19 +- util/allocator.h | 79 +++---- util/app.h | 10 +- util/args.h | 32 ++- util/basics.h | 205 +++++++++++++++++- util/threading.cc | 400 ++++++++++++++++++++++++++++++++++++ util/threading.h | 294 ++++---------------------- 26 files changed, 1311 insertions(+), 971 deletions(-) create mode 100644 util/threading.cc diff --git a/BUILD.bazel b/BUILD.bazel index 08a45b9..e5a7939 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -30,8 +30,11 @@ cc_library( cc_library( name = "threading", + srcs = ["util/threading.cc"], hdrs = ["util/threading.h"], deps = [ + ":basics", + # Placeholder for container detection, do not remove "@highway//:hwy", "@highway//:thread_pool", "@highway//:topology", @@ -173,7 +176,9 @@ cc_test( tags = ["hwy_ops_test"], deps = [ ":allocator", + ":basics", ":ops", + ":test_util", ":threading", "@googletest//:gtest_main", # buildcleaner: keep "//compression:compress", @@ -280,6 +285,7 @@ cc_library( ":kv_cache", ":weights", ":threading", + "//compression:compress", "//compression:io", "//compression:sfp", "//paligemma:image", @@ -307,6 +313,7 @@ cc_library( name = "args", hdrs = ["util/args.h"], deps = [ + ":basics", "//compression:io", "@highway//:hwy", ], @@ -317,6 +324,7 @@ cc_library( hdrs = ["util/app.h"], deps = [ ":args", + ":basics", ":common", ":gemma_lib", ":threading", @@ -342,8 +350,6 @@ cc_library( "//compression:compress", "@highway//:hwy", "@highway//:nanobenchmark", - "@highway//:thread_pool", - "@highway//:topology", ], ) @@ -583,6 +589,7 @@ cc_test( }, deps = [ ":backprop", + ":basics", ":common", ":gemma_lib", ":optimizer", diff --git a/CMakeLists.txt b/CMakeLists.txt index 1990481..876da6c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -101,6 +101,7 @@ set(SOURCES util/args.h util/basics.h util/test_util.h + util/threading.cc util/threading.h ) diff --git a/backprop/optimize_test.cc b/backprop/optimize_test.cc index 36d284b..6d83de0 100644 --- a/backprop/optimize_test.cc +++ b/backprop/optimize_test.cc @@ -33,13 +33,15 @@ #include "gemma/configs.h" #include "gemma/gemma.h" #include "gemma/weights.h" +#include "util/basics.h" #include "util/threading.h" #include "hwy/contrib/thread_pool/thread_pool.h" namespace gcpp { TEST(OptimizeTest, GradientDescent) { - NestedPools pools(1, /*pin=*/0, BoundedSlice(0, 1), BoundedSlice(0, 1)); + NestedPools pools(1, /*pin=*/Tristate::kFalse, BoundedSlice(0, 1), + BoundedSlice(0, 
1)); hwy::ThreadPool& pool = pools.Pool(); std::mt19937 gen(42); diff --git a/compression/compress.h b/compression/compress.h index 172c86a..9050d53 100644 --- a/compression/compress.h +++ b/compression/compress.h @@ -33,6 +33,7 @@ #include "compression/blob_store.h" #include "compression/io.h" #include "compression/shared.h" +#include "util/basics.h" // IWYU pragma: end_exports #include "util/allocator.h" #if COMPRESS_STATS @@ -62,7 +63,9 @@ class MatPtr { num_elements_(rows * cols), rows_(rows), cols_(cols), - ptr_(nullptr) {} + ptr_(nullptr) { + stride_ = cols; + } // Default is to leave all fields default-initialized. MatPtr() = default; virtual ~MatPtr(); @@ -85,7 +88,9 @@ class MatPtr { element_size_(key2.hi), num_elements_(key2.lo), rows_(key3.lo), - cols_(key3.hi) {} + cols_(key3.hi) { + stride_ = cols_; + } // Adds the contents entry to the table of contents. void AddToToc(std::vector& toc) const { @@ -137,6 +142,12 @@ class MatPtr { // Returns the number of columns in the 2-d array (inner dimension). size_t Cols() const { return cols_; } + Extents2D Extents() const { return Extents2D(rows_, cols_); } + + // Currently same as cols, but may differ in the future. This is the offset by + // which to advance pointers to the next row. + size_t Stride() const { return stride_; } + // Decoded elements should be multiplied by this to restore their original // range. This is required because SfpStream can only encode a limited range // of magnitudes. @@ -187,6 +198,8 @@ class MatPtr { // freed. The underlying memory is owned by a subclass or some external class // and must outlive this object. void* ptr_ = nullptr; + + size_t stride_; }; // MatPtrT adds a single template argument to MatPtr for an explicit type. @@ -288,7 +301,15 @@ decltype(auto) MatPtr::CallUpcasted(FuncT& func, TArgs&&... args) { } } +template +ConstMat ConstMatFromWeights(const MatPtrT& m, size_t ofs = 0) { + ConstMat mat = MakeConstMat(const_cast(m.data()), m.Extents(), ofs); + mat.scale = m.scale(); + return mat; +} + // MatStorageT adds the actual data storage to MatPtrT. +// TODO: use Extents2D instead of rows and cols. template class MatStorageT : public MatPtrT { public: diff --git a/compression/shared.h b/compression/shared.h index 7d90959..40c8f1c 100644 --- a/compression/shared.h +++ b/compression/shared.h @@ -267,8 +267,12 @@ struct PackedSpan { // check the compressed count and ensure we have that many. 
const size_t required = CompressedArrayElements(packed_ofs + num_accessible); - HWY_DASSERT(num >= required); - (void)required; + if constexpr (HWY_IS_DEBUG_BUILD) { + if (num < required) { + HWY_ABORT("PackedSpan: ofs %zu, want %zu, req %zu > %zu packed", + packed_ofs, num_accessible, required, num); + } + } } Packed* HWY_RESTRICT ptr; diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc index 1af2f1a..8c84f96 100644 --- a/evals/benchmark_helper.cc +++ b/evals/benchmark_helper.cc @@ -229,12 +229,12 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app, fprintf(stderr, "Date & Time : %s" // dt includes \n "CPU : %s\n" - "CPU topology : %s\n" + "CPU topology : %s, %s\n" "Instruction set : %s (%zu bits)\n" "Compiled config : %s\n" "Weight Type : %s\n" "EmbedderInput Type : %s\n", - dt, cpu100, pools.TopologyString(), + dt, cpu100, pools.TopologyString(), pools.PinString(), hwy::TargetName(hwy::DispatchedTarget()), hwy::VectorBytes() * 8, CompiledConfig(), StringFromType(loader.Info().weight), TypeName()); diff --git a/gemma/activations.h b/gemma/activations.h index 6b39854..3863325 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -72,18 +72,11 @@ struct Activations { size_t seq_len; size_t cache_pos_size = 0; - // Multi-Head Attention? - bool IsMHA() const { return layer_config.heads == layer_config.kv_heads; } - - // Stride between subsequent queries. Each of Q, K, V are of length kQKVDim, - // but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V. - size_t QStride() const { return layer_config.qkv_dim * (IsMHA() ? 3 : 1); } - static RowVectorBatch CreateInvTimescale(size_t qkv_dim, PostQKType post_qk) { const size_t rope_dim = post_qk == PostQKType::HalfRope ? qkv_dim / 2 : qkv_dim; - RowVectorBatch inv_timescale(1, rope_dim / 2); + RowVectorBatch inv_timescale(Extents2D(1, rope_dim / 2)); for (size_t dim = 0; dim < rope_dim / 2; ++dim) { const float freq_exponents = static_cast(2 * dim) / static_cast(rope_dim); @@ -100,29 +93,31 @@ struct Activations { const size_t ff_hidden_dim = layer_config.ff_hidden_dim; const size_t vocab_size = weights_config.vocab_size; - x = RowVectorBatch(batch_size, model_dim); - q = RowVectorBatch(batch_size, layer_config.heads * QStride()); + x = RowVectorBatch(Extents2D(batch_size, model_dim)); + q = RowVectorBatch( + Extents2D(batch_size, layer_config.heads * layer_config.QStride())); if (vocab_size > 0) { - logits = RowVectorBatch(batch_size, vocab_size); + logits = RowVectorBatch(Extents2D(batch_size, vocab_size)); } - pre_att_rms_out = RowVectorBatch(batch_size, model_dim); - att = RowVectorBatch(batch_size, - layer_config.heads * weights_config.seq_len); - att_out = RowVectorBatch(batch_size, - layer_config.heads * layer_config.qkv_dim); - att_sums = RowVectorBatch(batch_size, model_dim); + pre_att_rms_out = RowVectorBatch(Extents2D(batch_size, model_dim)); + att = RowVectorBatch( + Extents2D(batch_size, layer_config.heads * weights_config.seq_len)); + att_out = RowVectorBatch( + Extents2D(batch_size, layer_config.heads * layer_config.qkv_dim)); + att_sums = RowVectorBatch(Extents2D(batch_size, model_dim)); - bf_pre_ffw_rms_out = RowVectorBatch(batch_size, model_dim); - C1 = RowVectorBatch(batch_size, ff_hidden_dim); - C2 = RowVectorBatch(batch_size, ff_hidden_dim); - ffw_out = RowVectorBatch(batch_size, model_dim); + bf_pre_ffw_rms_out = RowVectorBatch(Extents2D(batch_size, model_dim)); + C1 = RowVectorBatch(Extents2D(batch_size, ff_hidden_dim)); + C2 = 
RowVectorBatch(Extents2D(batch_size, ff_hidden_dim)); + ffw_out = RowVectorBatch(Extents2D(batch_size, model_dim)); if (layer_config.type == LayerAttentionType::kGriffinRecurrentBlock) { - griffin_x = RowVectorBatch(batch_size, model_dim); - griffin_y = RowVectorBatch(batch_size, model_dim); - griffin_gate_x = RowVectorBatch(batch_size, model_dim); - griffin_multiplier = RowVectorBatch(batch_size, model_dim); + griffin_x = RowVectorBatch(Extents2D(batch_size, model_dim)); + griffin_y = RowVectorBatch(Extents2D(batch_size, model_dim)); + griffin_gate_x = RowVectorBatch(Extents2D(batch_size, model_dim)); + griffin_multiplier = + RowVectorBatch(Extents2D(batch_size, model_dim)); } inv_timescale = CreateInvTimescale(layer_config.qkv_dim, post_qk); diff --git a/gemma/configs.h b/gemma/configs.h index 706eabb..f6a4245 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -119,6 +119,13 @@ enum class Model { struct LayerConfig { size_t CacheLayerSize() const { return kv_heads * qkv_dim * 2; } + // Multi-Head Attention? + bool IsMHA() const { return heads == kv_heads; } + + // Stride between subsequent queries. Each of Q, K, V are of length kQKVDim, + // but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V. + size_t QStride() const { return qkv_dim * (IsMHA() ? 3 : 1); } + size_t model_dim = 0; size_t griffin_dim = 0; size_t ff_hidden_dim = 0; diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 028949c..c58f9a8 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -20,9 +20,9 @@ #include #include // std::min -#include #include +#include "compression/compress.h" #include "gemma/activations.h" #include "gemma/common.h" #include "gemma/configs.h" @@ -31,6 +31,7 @@ // Placeholder for internal test4, do not remove #include "paligemma/image.h" #include "util/allocator.h" +#include "util/basics.h" #include "util/threading.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" @@ -232,49 +233,49 @@ class GemmaAttention { // KV directly to KVCache. HWY_NOINLINE void ComputeQKV(const size_t num_interleaved) { PROFILER_ZONE("Gen.Attention.QKV"); - // For the computation of Q, K, and V, it is useful to remember that - // qkv_einsum_w has shape [(layer_config_.heads + layer_config_.kv_heads * - // 2), kKQVDim, layer_config_.model_dim] and q_stride_ = - // layer_config_.qkv_dim * (is_mha_ ? 3 : 1); + const size_t model_dim = layer_config_.model_dim; + const size_t qkv_dim = layer_config_.qkv_dim; + const size_t heads = layer_config_.heads; + const size_t kv_heads = layer_config_.kv_heads; const auto pre_att_rms_out = - ConstMat(activations_.pre_att_rms_out.All(), layer_config_.model_dim); - const auto w_q1 = layer_weights_.qkv_einsum_w.data() == nullptr - ? ConstMat(layer_weights_.qkv_einsum_w1.data(), - layer_config_.model_dim) - : ConstMat(layer_weights_.qkv_einsum_w.data(), - layer_config_.model_dim); - const auto w_q2 = - layer_weights_.qkv_einsum_w.data() == nullptr - ? ConstMat(layer_weights_.qkv_einsum_w2.data(), - layer_config_.model_dim) - : ConstMat(layer_weights_.qkv_einsum_w.data(), - layer_config_.model_dim, layer_config_.model_dim, - layer_config_.heads * layer_config_.qkv_dim * - layer_config_.model_dim); - MatMul( - num_interleaved, pre_att_rms_out, w_q1, - layer_weights_.qkv_einsum_w.scale(), /*add=*/nullptr, activations_.env, - MutableMat(activations_.q.All(), layer_config_.heads * q_stride_)); + ConstMatFromBatch(num_interleaved, activations_.pre_att_rms_out); + auto w_q1 = layer_weights_.qkv_einsum_w.data() + ? 
ConstMatFromWeights(layer_weights_.qkv_einsum_w) + : ConstMatFromWeights(layer_weights_.qkv_einsum_w1); + // The original qkv_einsum_w has shape [(heads + kv_heads * 2), kKQVDim, + // model_dim], which we reshaped to (heads + kv_heads * 2) * kKQVDim rows. + // We must shrink to the actual size because MatMul verifies + // `B.extents.rows == C.Cols()`. If MHA, `QStride() == 3 * qkv_dim` and all + // rows are used. Otherwise, `QStride() == qkv_dim` and KV will be + // computed in the second MatMul. + const size_t w1_rows = heads * layer_config_.QStride(); + w_q1.ShrinkRows(w1_rows); + MatMul(pre_att_rms_out, w_q1, + /*add=*/nullptr, activations_.env, RowPtrFromBatch(activations_.q)); if (is_mha_) { // Multi-Head Attention a.k.a. "use_qkv_einsum" computed QKV already. } else { + auto w_q2 = layer_weights_.qkv_einsum_w.data() + ? ConstMatFromWeights(layer_weights_.qkv_einsum_w, + w1_rows * model_dim) + : ConstMatFromWeights(layer_weights_.qkv_einsum_w2); + // KV structure is [k, v, k, v, ....] = kv_heads pairs of (k, v). + const size_t w_rows_kv_cols = kv_heads * 2 * qkv_dim; + w_q2.ShrinkRows(w_rows_kv_cols); + // Single query and no wraparound means we can use a matmul and write // directly into the KV cache with a stride of cache_pos_size_. if (num_queries_ == 1 && queries_pos_[0] + num_tokens_ <= div_seq_len_.GetDivisor()) { const size_t kv_ofs = queries_pos_[0] * cache_pos_size_ + layer_ * cache_layer_size_; - // KV structure is [k, v, k, v, ....] = layer_config_.kv_heads pairs of - // (k, v). float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs; - MatMul( - num_tokens_, pre_att_rms_out, w_q2, - layer_weights_.qkv_einsum_w.scale(), /*add=*/nullptr, - activations_.env, - MutableMat(kv, layer_config_.kv_heads * 2 * layer_config_.qkv_dim, - cache_pos_size_)); + RowPtrF kv_rows(kv, w_rows_kv_cols); + kv_rows.SetStride(cache_pos_size_); + MatMul(pre_att_rms_out, w_q2, + /*add=*/nullptr, activations_.env, kv_rows); } else { // Proceed row by row because there will be wraparound. for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved; @@ -288,40 +289,34 @@ class GemmaAttention { const size_t kv_offset = cache_pos * cache_pos_size_ + layer_ * cache_layer_size_; float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset; - // KV structure is [k, v, k, v, ....] = layer_config_.kv_heads pairs - // of (k, v). - if (layer_weights_.qkv_einsum_w.data() == nullptr) { - MatVec(layer_weights_.qkv_einsum_w2, 0, - layer_config_.kv_heads * 2 * layer_config_.qkv_dim, - layer_config_.model_dim, x, kv, pool_); + if (layer_weights_.qkv_einsum_w.data()) { + MatVec(layer_weights_.qkv_einsum_w, heads * qkv_dim * model_dim, + w_rows_kv_cols, model_dim, x, kv, pool_); } else { - MatVec(layer_weights_.qkv_einsum_w, - layer_config_.heads * layer_config_.qkv_dim * - layer_config_.model_dim, - layer_config_.kv_heads * 2 * layer_config_.qkv_dim, - layer_config_.model_dim, x, kv, pool_); + MatVec(layer_weights_.qkv_einsum_w2, 0, // + w_rows_kv_cols, model_dim, x, kv, pool_); } } } - } + } // !is_mha_ // Apply positional encodings for K (and copy KV to cache if MHA). 
- pool_.Run(0, layer_config_.kv_heads * num_interleaved, + pool_.Run(0, kv_heads * num_interleaved, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { - const size_t head = task % layer_config_.kv_heads; - const size_t interleaved_idx = task / layer_config_.kv_heads; + const size_t head = task % kv_heads; + const size_t interleaved_idx = task / kv_heads; const size_t query_idx = interleaved_idx % num_queries_; const size_t batch_idx = interleaved_idx / num_queries_; const size_t pos = queries_pos_[query_idx] + batch_idx; const size_t cache_pos = div_seq_len_.Remainder(pos); const size_t kv_offset = cache_pos * cache_pos_size_ + layer_ * cache_layer_size_ + - head * layer_config_.qkv_dim * 2; + head * qkv_dim * 2; KVCache& kv_cache = kv_caches_[query_idx]; float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset; const float* HWY_RESTRICT mha_kv = activations_.q.Batch(interleaved_idx) + head * q_stride_ + - layer_config_.qkv_dim; + qkv_dim; // Copy from `q` if MHA, or apply in-place. PositionalEncodingQK(is_mha_ ? mha_kv : kv, pos, layer_, 1.0f, @@ -329,9 +324,8 @@ class GemmaAttention { // If MHA, also copy V into KVCache. if (is_mha_) { - hwy::CopyBytes(mha_kv + layer_config_.qkv_dim, - kv + layer_config_.qkv_dim, - layer_config_.qkv_dim * sizeof(*kv)); + hwy::CopyBytes(mha_kv + qkv_dim, kv + qkv_dim, + qkv_dim * sizeof(*kv)); } }); } @@ -463,27 +457,14 @@ class GemmaAttention { HWY_DASSERT(layer_weights_.att_weights.data() != nullptr); HWY_DASSERT(activations_.att_out.All() != nullptr); HWY_DASSERT(activations_.att_sums.All() != nullptr); - if (layer_weights_.layer_config.softmax_attn_output_biases) { - MatMul( - num_interleaved, - ConstMat(activations_.att_out.All(), - layer_config_.heads * layer_config_.qkv_dim), - ConstMat(layer_weights_.att_weights.data(), - layer_config_.heads * layer_config_.qkv_dim), - layer_weights_.att_weights.scale(), - layer_weights_.attention_output_biases.data_scale1(), - activations_.env, - MutableMat(activations_.att_sums.All(), layer_config_.model_dim)); - } else { - MatMul( - num_interleaved, - ConstMat(activations_.att_out.All(), - layer_config_.heads * layer_config_.qkv_dim), - ConstMat(layer_weights_.att_weights.data(), - layer_config_.heads * layer_config_.qkv_dim), - layer_weights_.att_weights.scale(), nullptr, activations_.env, - MutableMat(activations_.att_sums.All(), layer_config_.model_dim)); - } + + const float* add = + layer_weights_.layer_config.softmax_attn_output_biases + ? 
layer_weights_.attention_output_biases.data_scale1() + : nullptr; + MatMul(ConstMatFromBatch(num_interleaved, activations_.att_out), + ConstMatFromWeights(layer_weights_.att_weights), add, + activations_.env, RowPtrFromBatch(activations_.att_sums)); } public: @@ -524,13 +505,13 @@ class GemmaAttention { num_queries_(queries_pos.size()), num_tokens_(num_tokens), layer_(layer), - q_stride_(activations.QStride()), + layer_config_(layer_weights->layer_config), + q_stride_(layer_config_.QStride()), cache_layer_size_(layer_weights->layer_config.CacheLayerSize()), cache_pos_size_(activations.cache_pos_size), - is_mha_(activations.IsMHA()), + is_mha_(layer_config_.IsMHA()), activations_(activations), layer_weights_(*layer_weights), - layer_config_(layer_weights->layer_config), div_seq_len_(div_seq_len), kv_caches_(kv_caches), pool_(activations.env.Pool()) { @@ -552,6 +533,7 @@ class GemmaAttention { const size_t num_queries_; const size_t num_tokens_; const size_t layer_; + const LayerConfig& layer_config_; const size_t q_stride_ = 0; const size_t cache_layer_size_ = 0; const size_t cache_pos_size_ = 0; @@ -559,7 +541,6 @@ class GemmaAttention { Activations& activations_; const LayerWeightsPtrs& layer_weights_; - const LayerConfig& layer_config_; const hwy::Divisor& div_seq_len_; const KVCaches& kv_caches_; hwy::ThreadPool& pool_; @@ -601,17 +582,13 @@ class VitAttention { // Computes Q, K, V for all heads, stored in activations_.q. HWY_NOINLINE void ComputeQKV() { PROFILER_ZONE("Gen.VitAttention.QKV"); - const auto y = - ConstMat(activations_.pre_att_rms_out.All(), layer_config_.model_dim); auto& qkv = activations_.q; HWY_ASSERT(qkv.BatchSize() == num_tokens_); - HWY_ASSERT(qkv.Len() == layer_config_.heads * 3 * layer_config_.qkv_dim); - MatMul( - num_tokens_, y, - ConstMat(layer_weights_.vit.qkv_einsum_w.data_scale1(), - layer_config_.model_dim), - /*scale=*/1.0f, layer_weights_.vit.qkv_einsum_b.data_scale1(), - activations_.env, MutableMat(qkv.All(), qkv.Len())); + HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim); + MatMul(ConstMatFromBatch(num_tokens_, activations_.pre_att_rms_out), + ConstMatFromWeights(layer_weights_.vit.qkv_einsum_w), + layer_weights_.vit.qkv_einsum_b.data_scale1(), activations_.env, + RowPtrFromBatch(qkv)); } HWY_NOINLINE void DotSoftmaxWeightedSum() { @@ -658,17 +635,13 @@ class VitAttention { HWY_NOINLINE void SumHeads() { PROFILER_ZONE("Gen.VitAttention.SumHeads"); auto* bias = layer_weights_.vit.attn_out_b.data_scale1(); - auto att_out = ConstMat(activations_.att_out.All(), - layer_config_.heads * layer_config_.qkv_dim); - auto att_weights = ConstMat(layer_weights_.vit.attn_out_w.data_scale1(), - layer_config_.heads * layer_config_.qkv_dim); - auto att_sums = - MutableMat(activations_.att_sums.All(), layer_config_.model_dim); // att_weights and att_out are concatenated heads, each of length // layer_config_.qkv_dim. Thus the [num_tokens_, layer_config_.model_dim] // matmul output is the sum over heads. 
- MatMul(num_tokens_, att_out, att_weights, /*scale=*/1.0f, - bias, activations_.env, att_sums); + auto att_out = ConstMatFromBatch(num_tokens_, activations_.att_out); + auto att_weights = ConstMatFromWeights(layer_weights_.vit.attn_out_w); + auto att_sums = RowPtrFromBatch(activations_.att_sums); + MatMul(att_out, att_weights, bias, activations_.env, att_sums); } public: @@ -720,125 +693,94 @@ HWY_NOINLINE void FFWNoVit(Activations& activations, size_t num_interleaved, PROFILER_ZONE("Gen.FFW"); const size_t model_dim = layer_weights->layer_config.model_dim; const size_t ffh_hidden_dim = layer_weights->layer_config.ff_hidden_dim; - const bool add_bias = layer_weights->layer_config.ff_biases; using WeightType = T; HWY_DASSERT(num_interleaved <= activations.bf_pre_ffw_rms_out.BatchSize()); - // Define slightly more readable names for the weights and activations. - const auto x = ConstMat(activations.bf_pre_ffw_rms_out.All(), model_dim); - Mat w1; - const float* bias1 = nullptr; - Mat w2; - const float* bias2 = nullptr; - float scale = 1.0f; - Mat w_output; - const float* output_bias = nullptr; - float output_scale = 1.0f; - auto hidden_activations = MutableMat(activations.C1.All(), ffh_hidden_dim); - auto multiplier = MutableMat(activations.C2.All(), ffh_hidden_dim); - auto ffw_out = MutableMat(activations.ffw_out.All(), model_dim); + const bool add_bias = layer_weights->layer_config.ff_biases; + const float* bias1 = + add_bias ? layer_weights->ffw_gating_biases.data_scale1() : nullptr; + const float* bias2 = add_bias ? bias1 + ffh_hidden_dim : nullptr; + const float* output_bias = + add_bias ? layer_weights->ffw_output_biases.data_scale1() : nullptr; - // For some of the weights and activations, it depends on the config where to - // get them from or whether to use them at all. - bias1 = layer_weights->ffw_gating_biases.data_scale1(); - bias2 = bias1 + ffh_hidden_dim; - output_bias = layer_weights->ffw_output_biases.data_scale1(); - w1 = layer_weights->gating_einsum_w.data() == nullptr - ? ConstMat(layer_weights->gating_einsum_w1.data(), model_dim) - : ConstMat(layer_weights->gating_einsum_w.data(), model_dim); - w2 = layer_weights->gating_einsum_w.data() == nullptr - ? ConstMat(layer_weights->gating_einsum_w2.data(), model_dim) - : ConstMat(layer_weights->gating_einsum_w.data(), model_dim, - model_dim, model_dim * ffh_hidden_dim); - scale = layer_weights->gating_einsum_w.data() == nullptr - ? layer_weights->gating_einsum_w1.scale() - : layer_weights->gating_einsum_w.scale(); - w_output = ConstMat(layer_weights->linear_w.data(), ffh_hidden_dim); - output_scale = layer_weights->linear_w.scale(); + // Define slightly more readable names for the weights and activations. + const auto x = + ConstMatFromBatch(num_interleaved, activations.bf_pre_ffw_rms_out); + + auto hidden_activations = RowPtrFromBatch(activations.C1); + auto multiplier = RowPtrFromBatch(activations.C2); + auto ffw_out = RowPtrFromBatch(activations.ffw_out); + + // gating_einsum_w holds two half-matrices. We plan to change the importer to + // avoid this confusion by splitting into gating_einsum_w1 and + // gating_einsum_w2. + const bool split = !!layer_weights->gating_einsum_w.data(); + auto w1 = split ? ConstMatFromWeights(layer_weights->gating_einsum_w) + : ConstMatFromWeights(layer_weights->gating_einsum_w1); + auto w2 = split ? 
ConstMatFromWeights(layer_weights->gating_einsum_w, + model_dim * ffh_hidden_dim) + : ConstMatFromWeights(layer_weights->gating_einsum_w2); + if (split) { + // Ensure that B.Extents().row matches C.Cols() because MatMul checks that. + w1.ShrinkRows(ffh_hidden_dim); + w2.ShrinkRows(ffh_hidden_dim); + } + auto w_output = ConstMatFromWeights(layer_weights->linear_w); // Compute the hidden layer activations. - if (add_bias) { - MatMul(num_interleaved, x, w1, scale, bias1, - activations.env, hidden_activations); - MatMul(num_interleaved, x, w2, scale, bias2, - activations.env, multiplier); - } else { - MatMul(num_interleaved, x, w1, scale, bias1, - activations.env, hidden_activations); - MatMul(num_interleaved, x, w2, scale, bias2, - activations.env, multiplier); - } + MatMul(x, w1, bias1, activations.env, hidden_activations); + MatMul(x, w2, bias2, activations.env, multiplier); // Activation (Gelu) and maybe multiply by gate. Store activations in act. - Activation(layer_weights->layer_config.activation, hidden_activations.ptr, - multiplier.ptr, ffh_hidden_dim * num_interleaved); + Activation(layer_weights->layer_config.activation, hidden_activations.Row(0), + multiplier.Row(0), ffh_hidden_dim * num_interleaved); // Hidden layer -> output layer. - if (add_bias) { - MatMul(num_interleaved, ConstMat(hidden_activations), - w_output, output_scale, output_bias, - activations.env, ffw_out); - } else { - MatMul(num_interleaved, ConstMat(hidden_activations), - w_output, output_scale, output_bias, - activations.env, ffw_out); - } + auto activations_mat = MakeConstMat( + hidden_activations.Row(0), Extents2D(num_interleaved, ffh_hidden_dim)); + + MatMul(activations_mat, w_output, output_bias, activations.env, ffw_out); } +// Same as FFWNoVit, but with different layer_weights members and no second +// gating matrix. template HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved, const LayerWeightsPtrs* layer_weights) { PROFILER_ZONE("Gen.FFW"); - const size_t model_dim = layer_weights->layer_config.model_dim; const size_t ff_hidden_dim = layer_weights->layer_config.ff_hidden_dim; - const bool add_bias = layer_weights->layer_config.ff_biases; using WeightType = typename LayerWeightsPtrs::WeightF32OrBF16; HWY_DASSERT(num_interleaved <= activations.bf_pre_ffw_rms_out.BatchSize()); - // Define slightly more readable names for the weights and activations. - const auto x = ConstMat(activations.bf_pre_ffw_rms_out.All(), model_dim); - Mat w1; - const float* bias1 = nullptr; - float scale = 1.0f; - Mat w_output; - const float* output_bias = nullptr; - float output_scale = 1.0f; - auto hidden_activations = MutableMat(activations.C1.All(), ff_hidden_dim); - auto multiplier = MutableMat(activations.C2.All(), ff_hidden_dim); - auto ffw_out = MutableMat(activations.ffw_out.All(), model_dim); + const bool add_bias = layer_weights->layer_config.ff_biases; + const float* bias1 = + add_bias ? layer_weights->vit.linear_0_b.data_scale1() : nullptr; + const float* output_bias = + add_bias ? layer_weights->vit.linear_1_b.data_scale1() : nullptr; - // For some of the weights and activations, it depends on the config where to - // get them from or whether to use them at all. 
- w1 = ConstMat(layer_weights->vit.linear_0_w.data_scale1(), model_dim); - bias1 = layer_weights->vit.linear_0_b.data_scale1(); - multiplier.ptr = nullptr; - w_output = - ConstMat(layer_weights->vit.linear_1_w.data_scale1(), ff_hidden_dim); - output_bias = layer_weights->vit.linear_1_b.data_scale1(); + // Define slightly more readable names for the weights and activations. + const auto x = + ConstMatFromBatch(num_interleaved, activations.bf_pre_ffw_rms_out); + + auto hidden_activations = RowPtrFromBatch(activations.C1); + auto ffw_out = RowPtrFromBatch(activations.ffw_out); + + auto w1 = ConstMatFromWeights(layer_weights->vit.linear_0_w); + auto w_output = ConstMatFromWeights(layer_weights->vit.linear_1_w); // Compute the hidden layer activations. - if (add_bias) { - MatMul(num_interleaved, x, w1, scale, bias1, - activations.env, hidden_activations); - } else { - MatMul(num_interleaved, x, w1, scale, bias1, - activations.env, hidden_activations); - } + MatMul(x, w1, bias1, activations.env, hidden_activations); - // Activation (Gelu) and maybe multiply by gate. Store activations in act. - Activation(layer_weights->layer_config.activation, hidden_activations.ptr, - multiplier.ptr, ff_hidden_dim * num_interleaved); + // Activation (Gelu), store in act. + RowPtrF multiplier = RowPtrF(nullptr, 0); + Activation(layer_weights->layer_config.activation, hidden_activations.Row(0), + multiplier.Row(0), ff_hidden_dim * num_interleaved); // Hidden layer -> output layer. - if (add_bias) { - MatMul(num_interleaved, ConstMat(hidden_activations), - w_output, output_scale, output_bias, - activations.env, ffw_out); - } else { - MatMul(num_interleaved, ConstMat(hidden_activations), - w_output, output_scale, output_bias, - activations.env, ffw_out); - } + auto activations_mat = MakeConstMat( + hidden_activations.Row(0), Extents2D(num_interleaved, ff_hidden_dim)); + + MatMul(activations_mat, w_output, output_bias, activations.env, ffw_out); } // `batch_idx` indicates which row of `x` to write to. @@ -853,7 +795,7 @@ HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos, // Image tokens just need to be copied. if (image_tokens != nullptr && pos_in_prompt < image_tokens->BatchSize()) { hwy::CopyBytes(image_tokens->Batch(pos_in_prompt), x.Batch(batch_idx), - x.Len() * sizeof(x.Const()[0])); + x.Cols() * sizeof(x.Const()[0])); return; } @@ -942,7 +884,7 @@ HWY_NOINLINE void TransformerLayer(const QueriesPos& queries_pos, // the Big Vision codebase. See // github.com/google-research/big_vision/blob/main/big_vision/models/vit.py // TODO(keysers): consider adding a wrapper for both LayerNorm with RMSNorm and -// try mergig this with TransformerLayer. +// try merging this with TransformerLayer. 
template HWY_NOINLINE void VitTransformerLayer(size_t num_tokens, size_t layer, const LayerWeightsPtrs* layer_weights, @@ -953,7 +895,7 @@ HWY_NOINLINE void VitTransformerLayer(size_t num_tokens, size_t layer, auto& x = activations.x; HWY_DASSERT(x.BatchSize() == num_tokens); - HWY_DASSERT(x.Len() == model_dim); + HWY_DASSERT(x.Cols() == model_dim); // y = nn.LayerNorm()(x) // y ~ pre_att_rms_out @@ -1106,7 +1048,7 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image, const size_t patch_size = patch_width * patch_width * 3; HWY_DASSERT(weights.vit_img_embedding_kernel.NumElements() == patch_size * model_dim); - HWY_DASSERT(activations.x.Len() == model_dim); + HWY_DASSERT(activations.x.Cols() == model_dim); std::vector> image_patches(seq_len); for (size_t i = 0; i < seq_len; ++i) { image_patches[i] = hwy::AllocateAligned(patch_size); @@ -1118,11 +1060,11 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image, // This could be done as one MatMul like: // RowVectorBatch image_patches(kSeqLen, kPatchSize); // [Get patches] - // MatMul( - // kVitSeqLen, ConstMat(image_patches.All(), kPatchSize), - // ConstMat(weights.vit_img_embedding_kernel.data_scale1(), kPatchSize), - // /*scale=*/1.0f, weights.vit_img_embedding_bias.data_scale1(), - // activations.env, MutableMat(activations.x.All(), kVitModelDim)); + // MatMul( + // MatFromBatch(kVitSeqLen, image_patches), + // MatFromWeights(weights.vit_img_embedding_kernel), + // weights.vit_img_embedding_bias.data_scale1(), activations.env, + // RowPtrF(activations.x.All(), kVitModelDim)); // However, MatMul currently requires that // A.cols % (2 * hn::Lanes(hn::ScalableTag())) == 0 // which is not the case here. We should relax that requirement on MatMul and @@ -1163,11 +1105,10 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs& weights, activations.x.All(), vit_model_dim); // Apply head embedding into image_tokens of size of the LLM kModelDim. - MatMul( - num_tokens, ConstMat(activations.x.All(), vit_model_dim), - ConstMat(weights.vit_img_head_kernel.data_scale1(), vit_model_dim), - /*scale=*/1.0f, weights.vit_img_head_bias.data_scale1(), activations.env, - MutableMat(image_tokens.All(), weights.weights_config.model_dim)); + MatMul(ConstMatFromBatch(num_tokens, activations.x), + ConstMatFromWeights(weights.vit_img_head_kernel), + weights.vit_img_head_bias.data_scale1(), activations.env, + RowPtrFromBatch(image_tokens)); } // Generates one token for each query. `queries_token` is the previous token @@ -1299,7 +1240,6 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations, const QueriesPos& queries_prefix_end, const size_t query_idx_start, const KVCaches& kv_caches, TimingInfo& timing_info) { - const size_t model_dim = model.Config().model_dim; const size_t vocab_size = model.Config().vocab_size; const ModelWeightsPtrs& weights = *model.GetWeightsOfType(); @@ -1387,11 +1327,10 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations, { PROFILER_ZONE("Gen.EmbeddingMatmul"); // Compute logits from last layer activations. 
- MatMul( - num_queries, ConstMat(activations.x.All(), model_dim), - ConstMat(weights.embedder_input_embedding.data(), model_dim), - weights.embedder_input_embedding.scale(), /*add=*/nullptr, - activations.env, MutableMat(activations.logits.All(), vocab_size)); + MatMul(ConstMatFromBatch(num_queries, activations.x), + ConstMatFromWeights(weights.embedder_input_embedding), + /*add=*/nullptr, activations.env, + RowPtrFromBatch(activations.logits)); } PROFILER_ZONE("Gen.Softcap+Sample+Stream"); for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { diff --git a/gemma/gemma.cc b/gemma/gemma.cc index aa01b7f..739bddb 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -35,7 +35,6 @@ #include "util/threading.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/highway.h" -#include "hwy/profiler.h" // also uses SIMD namespace gcpp { @@ -119,12 +118,12 @@ struct GenerateImageTokensT { void Gemma::Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos, size_t prefix_end, KVCache& kv_cache, TimingInfo& timing_info) { - if (runtime_config.use_spinning) pools_.StartSpinning(); + pools_.MaybeStartSpinning(runtime_config.use_spinning); model_.CallForModelWeight( runtime_config, prompt, pos, prefix_end, kv_cache, pools_, timing_info); - if (runtime_config.use_spinning) pools_.StopSpinning(); + pools_.MaybeStopSpinning(runtime_config.use_spinning); } void Gemma::GenerateBatch(const RuntimeConfig& runtime_config, @@ -141,23 +140,23 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config, QueriesPos(prefix_end_vec.data(), prefix_end_vec.size()); } - if (runtime_config.use_spinning) pools_.StartSpinning(); + pools_.MaybeStartSpinning(runtime_config.use_spinning); model_.CallForModelWeight( runtime_config, queries_prompt, queries_pos, mutable_queries_prefix_end, kv_caches, pools_, timing_info); - if (runtime_config.use_spinning) pools_.StopSpinning(); + pools_.MaybeStopSpinning(runtime_config.use_spinning); } void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config, const Image& image, ImageTokens& image_tokens) { - if (runtime_config.use_spinning) pools_.StartSpinning(); + pools_.MaybeStartSpinning(runtime_config.use_spinning); model_.CallForModelWeight(runtime_config, image, image_tokens, pools_); - if (runtime_config.use_spinning) pools_.StopSpinning(); + pools_.MaybeStopSpinning(runtime_config.use_spinning); } // Non-template functions moved from gemma-inl.h to avoid ODR violations. diff --git a/gemma/gemma.h b/gemma/gemma.h index cee99f3..5df319f 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -121,7 +121,11 @@ struct RuntimeConfig { const ImageTokens *image_tokens = nullptr; // Whether to use thread spinning to reduce barrier synchronization latency. - bool use_spinning = true; + // Mutable so we can change kDefault to kTrue/kFalse during Generate, because + // RuntimeConfig is const there and is not passed to the Gemma ctor. This + // default decision is likely sufficient because it is based on whether + // threads are successfully pinned. + mutable Tristate use_spinning = Tristate::kDefault; // End-of-sequence token. int eos_id = EOS_ID; diff --git a/gemma/run.cc b/gemma/run.cc index 1a5231c..2c62bdb 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -16,7 +16,6 @@ // Command line text interface to gemma. #include -#include #include #include #include @@ -79,8 +78,8 @@ std::string GetPrompt(std::istream& input, int verbosity, } // The main Read-Eval-Print Loop. 
-void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args, - int verbosity, const AcceptFunc& accept_token, +void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app, + const InferenceArgs& args, const AcceptFunc& accept_token, std::string& eot_line) { PROFILER_ZONE("Gen.misc"); size_t abs_pos = 0; // across turns @@ -92,17 +91,18 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args, const bool have_image = !args.image_file.path.empty(); Image image; - std::unique_ptr image_tokens; + ImageTokens image_tokens; if (have_image) { - image_tokens = std::make_unique( - model.GetModelConfig().vit_seq_len, model.GetModelConfig().model_dim); + image_tokens = ImageTokens(Extents2D(model.GetModelConfig().vit_seq_len, + model.GetModelConfig().model_dim)); HWY_ASSERT(model.Info().training == ModelTraining::PALIGEMMA); HWY_ASSERT(image.ReadPPM(args.image_file.path)); image.Resize(); - RuntimeConfig runtime_config = {.verbosity = verbosity, .gen = &gen}; + RuntimeConfig runtime_config = { + .verbosity = app.verbosity, .gen = &gen, .use_spinning = app.spin}; double image_tokens_start = hwy::platform::Now(); - model.GenerateImageTokens(runtime_config, image, *image_tokens); - if (verbosity >= 1) { + model.GenerateImageTokens(runtime_config, image, image_tokens); + if (app.verbosity >= 1) { double image_tokens_duration = hwy::platform::Now() - image_tokens_start; fprintf(stderr, "\n\n[ Timing info ] Image token generation took: %d ms\n", @@ -122,7 +122,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args, abs_pos = 0; InitGenerator(args, gen); } - if (verbosity >= 2) { + if (app.verbosity >= 2) { std::cout << "\n[ End ]\n"; } } else { @@ -133,7 +133,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args, if (tokens_generated_this_turn == prompt_size + 1) { // first token of response token_text.erase(0, token_text.find_first_not_of(" \t\n")); - if (verbosity >= 1) { + if (app.verbosity >= 1) { std::cout << "\n\n"; } } @@ -144,7 +144,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args, while (true) { // Loop until user quits. tokens_generated_this_turn = 0; - std::string prompt_string = GetPrompt(std::cin, verbosity, eot_line); + std::string prompt_string = GetPrompt(std::cin, app.verbosity, eot_line); if (!std::cin) return; // If !eot_line.empty(), we append \n, so only look at the first 2 chars. if (prompt_string.size() >= 2 && prompt_string[0] == '%') { @@ -171,18 +171,17 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args, } } - TimingInfo timing_info = {.verbosity = verbosity}; - RuntimeConfig runtime_config = { - .verbosity = verbosity, - .gen = &gen, - .stream_token = stream_token, - .accept_token = accept_token, - }; + TimingInfo timing_info = {.verbosity = app.verbosity}; + RuntimeConfig runtime_config = {.verbosity = app.verbosity, + .gen = &gen, + .stream_token = stream_token, + .accept_token = accept_token, + .use_spinning = app.spin}; args.CopyTo(runtime_config); size_t prefix_end = 0; if (have_image) { - runtime_config.image_tokens = image_tokens.get(); - prompt.insert(prompt.begin(), image_tokens->BatchSize(), 0); + runtime_config.image_tokens = &image_tokens; + prompt.insert(prompt.begin(), image_tokens.BatchSize(), 0); prompt_size = prompt.size(); // The end of the prefix for prefix-LM style attention in Paligemma. // See Figure 2 of https://arxiv.org/abs/2407.07726. 
@@ -237,8 +236,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { std::cout << "\n" << instructions << "\n"; } - ReplGemma(model, kv_cache, inference, app.verbosity, AcceptFunc(), - app.eot_line); + ReplGemma(model, kv_cache, app, inference, AcceptFunc(), app.eot_line); } } // namespace gcpp diff --git a/gemma/weights.h b/gemma/weights.h index 60e9d13..ce2df43 100644 --- a/gemma/weights.h +++ b/gemma/weights.h @@ -95,11 +95,11 @@ struct LayerWeightsPtrs { config.model_dim}, .qkv_einsum_b = {"qkv_ein_b", (config.heads + 2 * config.kv_heads), config.qkv_dim}, - .linear_0_w = {"linear_0_w", config.model_dim, - config.ff_hidden_dim}, - .linear_0_b = {"linear_0_b", 1, config.ff_hidden_dim}, - .linear_1_w = {"linear_1_w", config.ff_hidden_dim, + .linear_0_w = {"linear_0_w", config.ff_hidden_dim, config.model_dim}, + .linear_0_b = {"linear_0_b", 1, config.ff_hidden_dim}, + .linear_1_w = {"linear_1_w", config.model_dim, + config.ff_hidden_dim}, .linear_1_b = {"linear_1_b", 1, config.model_dim}, .layer_norm_0_bias = {"ln_0_bias", 1, config.model_dim}, .layer_norm_0_scale = {"ln_0_scale", 1, config.model_dim}, @@ -349,14 +349,13 @@ struct ModelWeightsPtrs { vit_encoder_norm_bias("enc_norm_bias", 1, config.vit_model_dim), vit_encoder_norm_scale("enc_norm_scale", 1, config.vit_model_dim), vit_img_embedding_bias("img_emb_bias", 1, config.vit_model_dim), - vit_img_embedding_kernel( - "img_emb_kernel", - config.patch_width * config.patch_width * 3, - config.vit_model_dim), + vit_img_embedding_kernel("img_emb_kernel", + config.patch_width * config.patch_width * 3, + config.vit_model_dim), vit_img_pos_embedding("img_pos_emb", 256, config.vit_model_dim), vit_img_head_bias("img_head_bias", 1, config.model_dim), - vit_img_head_kernel("img_head_kernel", config.vit_model_dim, - config.model_dim), + vit_img_head_kernel("img_head_kernel", config.model_dim, + config.vit_model_dim), scale_names(config.scale_names), weights_config(config) { c_layers.reserve(config.layer_configs.size()); diff --git a/ops/dot_test.cc b/ops/dot_test.cc index 1970a8b..7fb8514 100644 --- a/ops/dot_test.cc +++ b/ops/dot_test.cc @@ -1011,14 +1011,14 @@ struct TestShortDotsT { // hence they require padding to one vector. 
const size_t padded_num = hwy::RoundUpTo(num, N); const size_t packed_num = CompressedArrayElements(num); - RowVectorBatch raw_w(1, padded_num); - RowVectorBatch raw_v(1, padded_num); - RowVectorBatch weights(1, packed_num); + RowVectorBatch raw_w(Extents2D(1, padded_num)); + RowVectorBatch raw_v(Extents2D(1, padded_num)); + RowVectorBatch weights(Extents2D(1, packed_num)); const PackedSpan w(weights.Batch(0), packed_num); - RowVectorBatch vectors(1, num); + RowVectorBatch vectors(Extents2D(1, num)); const PackedSpan v(vectors.Batch(0), num); - RowVectorBatch bufs(1, num); + RowVectorBatch bufs(Extents2D(1, num)); double* HWY_RESTRICT buf = bufs.Batch(0); for (size_t rep = 0; rep < hn::AdjustedReps(20); ++rep) { @@ -1107,11 +1107,11 @@ void TestAllDot() { constexpr size_t kReps = hn::AdjustedReps(40); const size_t num = 24 * 1024; - NestedPools pools(kMaxWorkers - 1, /*pin=*/1, BoundedSlice(0, 1), - BoundedSlice(0, 1)); - RowVectorBatch a(kMaxWorkers, num); - RowVectorBatch b(kMaxWorkers, num); - RowVectorBatch bufs(kMaxWorkers, num); + NestedPools pools(kMaxWorkers - 1, /*pin=*/Tristate::kDefault, + BoundedSlice(0, 1), BoundedSlice(0, 1)); + RowVectorBatch a(Extents2D(kMaxWorkers, num)); + RowVectorBatch b(Extents2D(kMaxWorkers, num)); + RowVectorBatch bufs(Extents2D(kMaxWorkers, num)); std::array all_stats; pools.Cluster(0, 0).Run(0, kReps, [&](const uint32_t rep, size_t thread) { diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 20eb916..8646f79 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -16,8 +16,9 @@ #include #include -#include "compression/compress.h" // IWYU pragma: keep, b/conditionally used #include "ops/matmul.h" // IWYU pragma: export +#include "util/allocator.h" +#include "util/basics.h" // Include guard for (potentially) SIMD code. #if defined(THIRD_PARTY_GEMMA_CPP_MATMUL_TOGGLE) == defined(HWY_TARGET_TOGGLE) @@ -30,7 +31,7 @@ #include "hwy/highway.h" // After highway.h #include "compression/compress-inl.h" -#include "hwy/contrib/math/math-inl.h" +#include "ops/ops-inl.h" HWY_BEFORE_NAMESPACE(); namespace gcpp { @@ -53,38 +54,20 @@ constexpr size_t kRegCols = 4; // generally `kRegRows`, but `batch_size % kRegRows` on the last row (if != 0). constexpr size_t kRegRows = kRegCols; -// NEON_BF16/SVE/AVX3_ZEN4 have instructions for bf16 * bf16 + f32 which are -// more efficient than f32 * f32 + f32 because they process twice as many lanes -// at a time. Any combination of A and B can be bf16: activations may already be -// bf16, and weights can be decompressed to bf16. -// -// The corresponding op is `ReorderWidenMulAccumulate`, and it is always -// supported, but only useful if it returns a single vector of pairwise sums -// `a[0] * b[0] + a[1] * b[1]`. On other targets, `ReorderWidenMulAccumulate` -// insteads return `a[1] * b[1]` in its `sum1` output. We cannot afford to keep -// a `sum1` for each of the `kRegRows * kRegCols` C vectors, and it would be -// expensive to add each `sum0` and `sum1`, hence we only 'decompress' A and B -// to bf16 if the native op is available. This will actually demote f32 -// activations to bf16. Otherwise, we decompress to f32 and use normal FMA. -using MulT = hwy::If; - -// Loads two vectors at a time with element type MulT from a row of transposed -// B. Called in a loop over col_ab. No bounds checking because `kRow` is -// actually from B columns, which we checked is a multiple of `kRegCols`. +// Loads two vectors at a time with element type hn::TFromD from a row of +// transposed B. Called in a loop over col_ab. 
No bounds checking because +// `kRow` is from B columns, which we checked is a multiple of `kRegCols`. template class BRow { static_assert(kRow < kRegRows); // which unrolled instance we are public: - BRow(const Mat& B, size_t row_b, size_t cols_c) - // B.cols * C.cols is the total number of elements, required for - // PackedSpan::BoundsCheck. - : B_(MakeSpan(B.ptr, B.ofs + B.cols * cols_c)), - B_ofs_(B.Row(row_b + kRow)) {} + BRow(const ConstMat& B, size_t row_b) + : B_(MakeSpan(B.ptr, B.ofs + B.Extents().Area())), + B_ofs_(B.Row(HWY_MIN(row_b + kRow, B.Extents().rows - 1))) {} - template > - HWY_INLINE void Load2(DM d, size_t col_ab, VM& b0, VM& b1) const { - static_assert(hwy::IsSame, MulT>()); + template > + HWY_INLINE void Load2(DR d, size_t col_ab, VR& b0, VR& b1) const { Decompress2(d, B_, B_ofs_ + col_ab, b0, b1); } @@ -93,11 +76,11 @@ class BRow { const size_t B_ofs_; }; -// Loads *two* row vectors from A via `Decompress2`, multiplies element-wise -// with `kRegRows` x 2 row vectors from transposed B, and adds them to -// `kRegRows` x `kRegCols` C vectors. The lanes of `C[r,c]` are thus a subset of -// the terms of the dot products that make up the MatMul result at `r,c`. -// No-op for the bottom-most tile where kRow >= kNumRows. +// Loads *two* row vectors from A via `Decompress2`, widens to f32, multiplies +// element-wise with `kRegRows` x 2 row vectors from transposed B, and adds +// them to `kRegRows` x `kRegCols` C vectors. The lanes of `C[r,c]` are thus a +// subset of the terms of the dot products that make up the MatMul result at +// `r,c`. No-op for the bottom-most rows whose `kRow >= kNumRows`. // // This approach is atypical because it requires a horizontal sum, for which we // introduce a fast and new(?) vector-length agnostic 'transpose', see @@ -107,22 +90,24 @@ class BRow { // - `Decompress2` decompresses two vectors at a time; // - B is column-major, so unit-stride SIMD loads return a column, not values // from different columns, i.e. a row. -// Both could be fixed in a packing stage, which is not implemented yet, and -// might not be necessary otherwise. However, `ReorderWidenMulAccumulate` is -// important for bf16 performance and incompatible with the conventional -// approach, because its pairwise adds would add together unrelated terms. -// By contrast, pairwise adds are fine when our C lanes are the terms of a -// single dot product, which can be reordered or pre-reduced. +// - `ReorderWidenMulAccumulate` is important for bf16 performance, but its +// pairwise adds would add together unrelated terms. +// The first two could be fixed in a packing stage, which is not implemented +// yet, and might not be necessary otherwise. The third seems a fundamental +// mismatch. However, pairwise adds are fine in our setting because C lanes are +// the terms of a single dot product, which can be reordered or pre-reduced. template class ALoadAccumulate { - static_assert(kRow < kRegRows); // which unrolled instance we are - public: - ALoadAccumulate(const Mat& A, size_t row_ac, size_t batch_size) - // A.cols * batch_size is the total number of elements, required for - // PackedSpan::BoundsCheck. - : A_(MakeSpan(A.ptr, A.ofs + A.cols * batch_size)), - A_ofs_(A.Row(row_ac + kRow)) {} + static_assert(kRow < kRegRows); // which unrolled instance we are + // `First` and `Next` handle a single row of A, so the horizontal sums of + // their `C0..3` are the (partial) dot products for 4 consecutive values in + // one row of C. 
+ static_assert(kRegCols == 4); + + ALoadAccumulate(const ConstMat& A, size_t row_ac) + : A_(MakeSpan(A.ptr, A.ofs + A.Extents().Area())), + A_ofs_(A.Row(HWY_MIN(row_ac + kRow, A.Extents().rows - 1))) {} // First iteration, col_ab = 0: initialize C0..3 instead of updating them. template , HWY_IF_F32_D(DM)> @@ -161,20 +146,27 @@ class ALoadAccumulate { Decompress2(dm, A_, A_ofs_, a0, a1); const DF df; - VF unused_sum1 = hn::Zero(df); static_assert(kRegCols == 4); C0 = hn::WidenMulPairwiseAdd(df, a0, b00); C1 = hn::WidenMulPairwiseAdd(df, a0, b10); C2 = hn::WidenMulPairwiseAdd(df, a0, b20); C3 = hn::WidenMulPairwiseAdd(df, a0, b30); - C0 = hn::ReorderWidenMulAccumulate(df, a1, b01, C0, unused_sum1); - C1 = hn::ReorderWidenMulAccumulate(df, a1, b11, C1, unused_sum1); - C2 = hn::ReorderWidenMulAccumulate(df, a1, b21, C2, unused_sum1); - C3 = hn::ReorderWidenMulAccumulate(df, a1, b31, C3, unused_sum1); - - // Ensure sum1 was indeed unused. - HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df)))); + if constexpr (HWY_NATIVE_DOT_BF16) { + // Native ReorderWidenMulAccumulate adds to C0..3 for free. + VF unused_sum1 = hn::Zero(df); + C0 = hn::ReorderWidenMulAccumulate(df, a1, b01, C0, unused_sum1); + C1 = hn::ReorderWidenMulAccumulate(df, a1, b11, C1, unused_sum1); + C2 = hn::ReorderWidenMulAccumulate(df, a1, b21, C2, unused_sum1); + C3 = hn::ReorderWidenMulAccumulate(df, a1, b31, C3, unused_sum1); + // Ensure sum1 was indeed unused. + HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df)))); + } else { + C0 = hn::Add(C0, hn::WidenMulPairwiseAdd(df, a1, b01)); + C1 = hn::Add(C1, hn::WidenMulPairwiseAdd(df, a1, b11)); + C2 = hn::Add(C2, hn::WidenMulPairwiseAdd(df, a1, b21)); + C3 = hn::Add(C3, hn::WidenMulPairwiseAdd(df, a1, b31)); + } } } @@ -217,20 +209,31 @@ class ALoadAccumulate { Decompress2(dm, A_, A_ofs_ + col_ab, a0, a1); const DF df; - hn::Vec unused_sum1 = hn::Zero(df); static_assert(kRegCols == 4); - C0 = hn::ReorderWidenMulAccumulate(df, a0, b00, C0, unused_sum1); - C1 = hn::ReorderWidenMulAccumulate(df, a0, b10, C1, unused_sum1); - C2 = hn::ReorderWidenMulAccumulate(df, a0, b20, C2, unused_sum1); - C3 = hn::ReorderWidenMulAccumulate(df, a0, b30, C3, unused_sum1); - C0 = hn::ReorderWidenMulAccumulate(df, a1, b01, C0, unused_sum1); - C1 = hn::ReorderWidenMulAccumulate(df, a1, b11, C1, unused_sum1); - C2 = hn::ReorderWidenMulAccumulate(df, a1, b21, C2, unused_sum1); - C3 = hn::ReorderWidenMulAccumulate(df, a1, b31, C3, unused_sum1); - - // Ensure sum1 was indeed unused. - HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df)))); + if constexpr (HWY_NATIVE_DOT_BF16) { + // Native ReorderWidenMulAccumulate adds to C0..3 for free. + VF unused_sum1 = hn::Zero(df); + C0 = hn::ReorderWidenMulAccumulate(df, a0, b00, C0, unused_sum1); + C1 = hn::ReorderWidenMulAccumulate(df, a0, b10, C1, unused_sum1); + C2 = hn::ReorderWidenMulAccumulate(df, a0, b20, C2, unused_sum1); + C3 = hn::ReorderWidenMulAccumulate(df, a0, b30, C3, unused_sum1); + C0 = hn::ReorderWidenMulAccumulate(df, a1, b01, C0, unused_sum1); + C1 = hn::ReorderWidenMulAccumulate(df, a1, b11, C1, unused_sum1); + C2 = hn::ReorderWidenMulAccumulate(df, a1, b21, C2, unused_sum1); + C3 = hn::ReorderWidenMulAccumulate(df, a1, b31, C3, unused_sum1); + // Ensure sum1 was indeed unused. 
+ HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df)))); + } else { + C0 = hn::Add(C0, hn::WidenMulPairwiseAdd(df, a0, b00)); + C1 = hn::Add(C1, hn::WidenMulPairwiseAdd(df, a0, b10)); + C2 = hn::Add(C2, hn::WidenMulPairwiseAdd(df, a0, b20)); + C3 = hn::Add(C3, hn::WidenMulPairwiseAdd(df, a0, b30)); + C0 = hn::Add(C0, hn::WidenMulPairwiseAdd(df, a1, b01)); + C1 = hn::Add(C1, hn::WidenMulPairwiseAdd(df, a1, b11)); + C2 = hn::Add(C2, hn::WidenMulPairwiseAdd(df, a1, b21)); + C3 = hn::Add(C3, hn::WidenMulPairwiseAdd(df, a1, b31)); + } } } @@ -356,116 +359,113 @@ class AddHorizontalSums { // Streams a `kNumRows` high strip of `A` and the transposed `B`, then writes a // *finished* tile of f32 `C` whose top left is (row_ac, row_b_col_c). // TODO: loop over sections instead of full rows and accumulate into `tile_c`. +// `buf` is 16 vectors of thread-local storage. template -HWY_INLINE void MatMulTile(const size_t batch_size, const Mat& A, - const Mat& B, const size_t row_ac, - const size_t row_b_col_c, const float scale, - const float* HWY_RESTRICT add, - float* HWY_RESTRICT buf, const Mat& C) { - // For 'decompressing' A and B into BF16 or float. - const hn::ScalableTag dm; - using VM = hn::Vec; - const size_t NM = hn::Lanes(dm); +HWY_INLINE void MatMulTile(const ConstMat& A, const size_t row_ac, + const ConstMat& B, const size_t row_b_col_c, + const float scale, const float* HWY_RESTRICT add, + float* HWY_RESTRICT buf, const RowPtr& C) { + // Decompress A and B to which type, which will then be widened to f32, + // multiplied, added once into f32, then promoted to f64 and accumulated. + // NEON_BF16/SVE/AVX3_ZEN4 have instructions for bf16 * bf16 + f32 which are + // more efficient than f32 * f32 + f32 because they process twice as many + // lanes at a time. If available, we definitely want to use them. Otherwise, + // bf16 is still worthwhile if A (activations) are bf16: SFP weights are + // cheaper to decode to bf16, relative to the minor extra cost of promoting + // bf16 when multiplying. However, if A is f32, demoting to bf16 can be + // expensive unless we also have native bf16 dot. 
+ using Raw = hwy::If(), BF16, float>; + const hn::ScalableTag dr; + using VR = hn::Vec; + const size_t NR = hn::Lanes(dr); + + const Range1D cols_ab(0, A.Extents().cols); + HWY_DASSERT(row_ac + kNumRows <= A.Extents().rows); + HWY_DASSERT(row_b_col_c + kNumRows <= B.Extents().rows); + HWY_DASSERT(cols_ab.end() % (2 * NR) == 0); static_assert(kRegRows == 4); - const BRow<0, MatTB> b_row0(B, row_b_col_c, C.cols); - const BRow<1, MatTB> b_row1(B, row_b_col_c, C.cols); - const BRow<2, MatTB> b_row2(B, row_b_col_c, C.cols); - const BRow<3, MatTB> b_row3(B, row_b_col_c, C.cols); + const BRow<0, MatTB> b_row0(B, row_b_col_c); + const BRow<1, MatTB> b_row1(B, row_b_col_c); + const BRow<2, MatTB> b_row2(B, row_b_col_c); + const BRow<3, MatTB> b_row3(B, row_b_col_c); - const ALoadAccumulate<0, MatTA> a_row0(A, row_ac, batch_size); - const ALoadAccumulate<1, MatTA> a_row1(A, row_ac, batch_size); - const ALoadAccumulate<2, MatTA> a_row2(A, row_ac, batch_size); - const ALoadAccumulate<3, MatTA> a_row3(A, row_ac, batch_size); + const ALoadAccumulate<0, MatTA> a_row0(A, row_ac); + const ALoadAccumulate<1, MatTA> a_row1(A, row_ac); + const ALoadAccumulate<2, MatTA> a_row2(A, row_ac); + const ALoadAccumulate<3, MatTA> a_row3(A, row_ac); - const hn::Repartition df; + const hn::Repartition df; using VF = hn::Vec; VF C00, C01, C02, C03; VF C10, C11, C12, C13; VF C20, C21, C22, C23; VF C30, C31, C32, C33; + size_t col_ab = cols_ab.begin(); { // First iteration initializes the `Crc` vectors. - VM b00, b01, b10, b11, b20, b21, b30, b31; - b_row0.Load2(dm, /*col_ab=*/0, b00, b01); - b_row1.Load2(dm, /*col_ab=*/0, b10, b11); - b_row2.Load2(dm, /*col_ab=*/0, b20, b21); - b_row3.Load2(dm, /*col_ab=*/0, b30, b31); + VR b00, b01, b10, b11, b20, b21, b30, b31; + b_row0.Load2(dr, col_ab, b00, b01); + b_row1.Load2(dr, col_ab, b10, b11); + b_row2.Load2(dr, col_ab, b20, b21); + b_row3.Load2(dr, col_ab, b30, b31); - a_row0.template First(dm, b00, b01, b10, b11, b20, b21, b30, b31, + a_row0.template First(dr, b00, b01, b10, b11, b20, b21, b30, b31, C00, C01, C02, C03); - a_row1.template First(dm, b00, b01, b10, b11, b20, b21, b30, b31, + a_row1.template First(dr, b00, b01, b10, b11, b20, b21, b30, b31, C10, C11, C12, C13); - a_row2.template First(dm, b00, b01, b10, b11, b20, b21, b30, b31, + a_row2.template First(dr, b00, b01, b10, b11, b20, b21, b30, b31, C20, C21, C22, C23); - a_row3.template First(dm, b00, b01, b10, b11, b20, b21, b30, b31, + a_row3.template First(dr, b00, b01, b10, b11, b20, b21, b30, b31, C30, C31, C32, C33); + col_ab += 2 * NR; } - // `2 * NM` per iteration because `Load2` returns two vectors. + // `2 * NR` per iteration because `Load2` returns two vectors. 
HWY_UNROLL(1) - for (size_t col_ab = 2 * NM; col_ab <= A.cols - 2 * NM; col_ab += 2 * NM) { - VM b00, b01, b10, b11, b20, b21, b30, b31; - b_row0.Load2(dm, col_ab, b00, b01); - b_row1.Load2(dm, col_ab, b10, b11); - b_row2.Load2(dm, col_ab, b20, b21); - b_row3.Load2(dm, col_ab, b30, b31); + for (; col_ab < cols_ab.end(); col_ab += 2 * NR) { + VR b00, b01, b10, b11, b20, b21, b30, b31; + b_row0.Load2(dr, col_ab, b00, b01); + b_row1.Load2(dr, col_ab, b10, b11); + b_row2.Load2(dr, col_ab, b20, b21); + b_row3.Load2(dr, col_ab, b30, b31); - a_row0.template Next(dm, col_ab, b00, b01, b10, b11, b20, b21, + a_row0.template Next(dr, col_ab, b00, b01, b10, b11, b20, b21, b30, b31, C00, C01, C02, C03); - a_row1.template Next(dm, col_ab, b00, b01, b10, b11, b20, b21, + a_row1.template Next(dr, col_ab, b00, b01, b10, b11, b20, b21, b30, b31, C10, C11, C12, C13); - a_row2.template Next(dm, col_ab, b00, b01, b10, b11, b20, b21, + a_row2.template Next(dr, col_ab, b00, b01, b10, b11, b20, b21, b30, b31, C20, C21, C22, C23); - a_row3.template Next(dm, col_ab, b00, b01, b10, b11, b20, b21, + a_row3.template Next(dr, col_ab, b00, b01, b10, b11, b20, b21, b30, b31, C30, C31, C32, C33); } // TODO: hoist into outer loop. - float* HWY_RESTRICT C_tile = C.ptr + C.Row(row_ac) + row_b_col_c; - InitC(add, row_b_col_c, C_tile, C.stride); + float* HWY_RESTRICT C_tile = C.Row(row_ac) + row_b_col_c; + InitC(add, row_b_col_c, C_tile, C.Stride()); AddHorizontalSums()(df, scale, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30, C31, C32, C33, - buf, C_tile, C.stride); + buf, C_tile, C.Stride()); } -// Computes the matrix product `A * B * scale [+ add]` and stores it in `C`. -// -// `A` is a row-major matrix of shape `(batch_size, A.cols)`. -// `B` is transposed; `B.cols`, which must match `A.cols`, denotes the number of -// rows in the original B, and `C.cols` the number of columns in the original B. -// -// `scale` allows expanding the smaller range of `SfpStream` to the original -// values. When `A` and/or `B` are from CompressedArray, `scale` should be the -// product of their `.scale()` values, otherwise 1.0f. -// -// If `kAdd` is true, the row-vector `add` is added to each row of `C`, -// otherwise `add` is ignored and can be nullptr. A scale for `add` is not -// supported, so make sure its scale is 1. -// -// `C` is a row-major matrix of size `(batch_size, C.cols)`. -// -// Updates 4x4 tiles of C in parallel using a work-stealing thread pool. -// Typically `batch_size` is 1..512, `A.cols` and `C.cols` are 3k or 24k. -// Must not be called concurrently with the same `env`. template -HWY_NOINLINE void MatMul(const size_t batch_size, const Mat& A, - const Mat& B, const float scale, - const float* HWY_RESTRICT add, MatMulEnv& env, - const Mat& C) { +HWY_NOINLINE void MatMulImpl(const ConstMat& A, const ConstMat& B, + const float* HWY_RESTRICT add, MatMulEnv& env, + const RowPtr& C) { // PROFILER_ZONE("Matmul"); - HWY_DASSERT(A.NotEmpty() && B.NotEmpty() && C.NotEmpty()); - HWY_DASSERT(A.cols == B.cols); + HWY_DASSERT(A.Extents().cols == B.Extents().cols); + const size_t batch_size = A.Extents().rows; + HWY_DASSERT(C.Cols() % kRegCols == 0); + HWY_DASSERT(C.Stride() >= C.Cols()); + HWY_DASSERT(B.Extents().rows == C.Cols()); - // Must be a multiple of two vectors because we Decompress2. - HWY_DASSERT(A.cols % (2 * hn::Lanes(hn::ScalableTag())) == 0); - HWY_DASSERT(C.cols % kRegCols == 0); + const float scale = A.scale * B.scale; // We currently write C directly, which touches more memory than fits in L3. 
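  // For example, at the upper end of the typical sizes (batch_size = 512,
  // C.Cols() = 24576), C alone is 512 * 24576 * 4 bytes = 48 MiB of f32,
  // more than most L3 caches.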
// TODO: add another level of loops to finish L3-sized pieces of C at a time. const size_t tilesY = hwy::DivCeil(batch_size, kRegRows); - const size_t tilesX = C.cols / kRegCols; + const size_t tilesX = C.Cols() / kRegCols; env.Pool().Run( 0, tilesX * tilesY, [&](const uint64_t idx_tile, size_t thread) HWY_ATTR { @@ -481,24 +481,45 @@ HWY_NOINLINE void MatMul(const size_t batch_size, const Mat& A, HWY_DASSERT(num_rows != 0); switch (num_rows) { case 1: - MatMulTile<1, kAdd>(batch_size, A, B, row_ac, row_b_col_c, scale, - add, buf, C); + MatMulTile<1, kAdd>(A, row_ac, B, row_b_col_c, scale, add, buf, C); break; case 2: - MatMulTile<2, kAdd>(batch_size, A, B, row_ac, row_b_col_c, scale, - add, buf, C); + MatMulTile<2, kAdd>(A, row_ac, B, row_b_col_c, scale, add, buf, C); break; case 3: - MatMulTile<3, kAdd>(batch_size, A, B, row_ac, row_b_col_c, scale, - add, buf, C); + MatMulTile<3, kAdd>(A, row_ac, B, row_b_col_c, scale, add, buf, C); break; default: - MatMulTile<4, kAdd>(batch_size, A, B, row_ac, row_b_col_c, scale, - add, buf, C); + MatMulTile<4, kAdd>(A, row_ac, B, row_b_col_c, scale, add, buf, C); } }); } +// Computes the matrix product `A * B * scale [+ add]` and stores it in `C`. +// +// `A` is a row-major matrix and `B` is transposed. Its `B.Extents().cols`, +// which must match `A.Extents().cols`, is the number of rows in the original B. +// +// If `add` is non-null, the row-vector `add` is added to each row of `C`. +// A scale for `add` is not supported, so make sure its scale is 1. +// +// `C` is a row-major matrix of size `(A.rows, C.Cols())` with support for +// arbitrary strides. +// +// Updates 4x4 tiles of C in parallel using a work-stealing thread pool. +// Typically `A.rows` is 1..512, `A.Extents().cols` and `B.Extents().rows` are +// 3k or 24k. Must not be called concurrently with the same `env`. +template +HWY_NOINLINE void MatMul(const ConstMat& A, const ConstMat& B, + const float* HWY_RESTRICT add, MatMulEnv& env, + const RowPtr& C) { + if (add) { + MatMulImpl(A, B, add, env, C); + } else { + MatMulImpl(A, B, nullptr, env, C); + } +} + // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace gcpp diff --git a/ops/matmul.h b/ops/matmul.h index c643062..2eff81f 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -19,73 +19,22 @@ #include // IWYU pragma: begin_exports +#include "util/basics.h" #include "util/threading.h" -#include "hwy/aligned_allocator.h" -#include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" // IWYU pragma: end_exports -#include "util/allocator.h" // RowVectorBatch #include "hwy/per_target.h" // VectorBytes namespace gcpp { -// Bundles ptr/size/stride arguments to simplify MatMul call sites. T can be -// const or non-const. Create via ConstMat/MutableMat. -// TODO(rays): Replace with MatPtr and get rid of stride, which is only != cols -// in one place. -template -struct Mat { - bool NotEmpty() const { - return ptr != nullptr && cols != 0 && stride >= cols; - } - size_t Row(size_t r) const { return ofs + stride * r; } - - T* HWY_RESTRICT ptr; - size_t cols; - - // elements between rows, which is typically the same as `cols`. - size_t stride; - - // Offset to add to `ptr`; separate because T=NuqStream does not support - // pointer arithmetic. 
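For orientation, a sketch of a call site after this change, built from the
ConstMat/RowPtr helpers in util/basics.h; `batch_size`, `cols`, `out_cols`,
`weights` (a MatPtrT, e.g. SFP) and `env` are placeholders rather than names
from this patch:

  RowVectorBatch<float> activations(Extents2D(batch_size, cols));
  RowVectorBatch<float> out(Extents2D(batch_size, out_cols));
  const auto A = ConstMatFromBatch(batch_size, activations);
  const auto B = ConstMatFromWeights(weights);
  const RowPtrF C = RowPtrFromBatch(out);
  MatMul(A, B, /*add=*/nullptr, env, C);  // C = A * B * (A.scale * B.scale)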
- size_t ofs; -}; - -template -Mat MutableMat(T* HWY_RESTRICT ptr, size_t cols, size_t stride, - size_t ofs = 0) { - return Mat{.ptr = ptr, .cols = cols, .stride = stride, .ofs = ofs}; -} - -template -Mat ConstMat(const T* HWY_RESTRICT ptr, size_t cols, size_t stride, - size_t ofs = 0) { - return Mat{.ptr = ptr, .cols = cols, .stride = stride, .ofs = ofs}; -} - -template -Mat ConstMat(Mat mat) { - return ConstMat(mat.ptr, mat.cols, mat.stride, mat.ofs); -} - -template -Mat MutableMat(T* HWY_RESTRICT ptr, size_t cols) { - return MutableMat(ptr, cols, cols); -} - -template -Mat ConstMat(const T* HWY_RESTRICT ptr, size_t cols) { - return ConstMat(ptr, cols, cols); -} - // Allocations and threads, shared across MatMul calls. class MatMulEnv { public: MatMulEnv() : pools_(nullptr) {} explicit MatMulEnv(NestedPools& pools) : pools_(&pools) { const size_t N = hwy::VectorBytes() / sizeof(float); - buf_ = RowVectorBatch(pools.MaxWorkers(), 16 * N); + buf_ = RowVectorBatch(Extents2D(pools.MaxWorkers(), 16 * N)); } RowVectorBatch& Buf() { return buf_; } diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc index 3b6c7bd..8d36acd 100644 --- a/ops/matmul_test.cc +++ b/ops/matmul_test.cc @@ -32,6 +32,7 @@ #include "compression/compress.h" #include "util/allocator.h" +#include "util/basics.h" #include "util/threading.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" @@ -55,19 +56,23 @@ namespace HWY_NAMESPACE { using FloatPtr = hwy::AlignedFreeUniquePtr; +template +using MatStoragePtr = std::unique_ptr>; + // Generates inputs: deterministic, within max SfpStream range. -template >> -MatPtr GenerateMat(size_t offset, hwy::ThreadPool& pool) { +template +MatStoragePtr GenerateMat(const Extents2D extents, + hwy::ThreadPool& pool) { gcpp::CompressWorkingSet ws; - auto mat = std::make_unique>("test", kRows, kCols); + auto mat = + std::make_unique>("mat", extents.rows, extents.cols); FloatPtr content = hwy::AllocateAligned(mat->NumElements()); HWY_ASSERT(content); - const float scale = SfpStream::kMax / (mat->NumElements() + offset); - pool.Run(0, kRows, [&](const size_t i, size_t /*thread*/) { - for (size_t j = 0; j < kCols; j++) { - content[i * kCols + j] = - static_cast((i * kCols + j + offset) * scale); + const float scale = SfpStream::kMax / (mat->NumElements()); + pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) { + for (size_t c = 0; c < extents.cols; c++) { + content[r * extents.cols + c] = + static_cast(r * extents.cols + c) * scale; } }); @@ -76,185 +81,173 @@ MatPtr GenerateMat(size_t offset, hwy::ThreadPool& pool) { return mat; } -template >> -MatPtr GenerateTransposedMat(size_t offset, hwy::ThreadPool& pool) { +// extents describes the transposed matrix. 
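// For example, if the logical (untransposed) B is K x N, callers pass
// Extents2D(N, K) here; TestMatMul below constructs it as
// Extents2D(cols_bc, cols_a_rows_b).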
+template +MatStoragePtr GenerateTransposedMat(const Extents2D extents, + hwy::ThreadPool& pool) { gcpp::CompressWorkingSet ws; - MatPtr mat = std::make_unique>("test", kCols, kRows); + auto mat = + std::make_unique>("trans", extents.rows, extents.cols); FloatPtr content = hwy::AllocateAligned(mat->NumElements()); - const float scale = SfpStream::kMax / (mat->NumElements() + offset); - pool.Run(0, kRows, [&](const size_t i, size_t /*thread*/) { - for (size_t j = 0; j < kCols; j++) { - content[j * kRows + i] = - static_cast((i * kCols + j + offset) * scale); + const float scale = SfpStream::kMax / (mat->NumElements()); + pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) { + for (size_t c = 0; c < extents.cols; c++) { + content[r * extents.cols + c] = + static_cast(c * extents.rows + r) * scale; } }); CompressScaled(content.get(), mat->NumElements(), ws, *mat, pool); - // Arbitrary value, different from 1, must match GenerateMatHeap. + // Arbitrary value, different from 1, must match GenerateMat. mat->set_scale(0.6f); return mat; } -template >> -MatPtr GenerateZeroMat(hwy::ThreadPool& pool) { - gcpp::CompressWorkingSet ws; - auto mat = std::make_unique>("Array", kRows, kCols); - FloatPtr content = hwy::AllocateAligned(mat->NumElements()); - HWY_ASSERT(content); - - pool.Run(0, kRows, [&](const size_t i, size_t thread) { - hwy::ZeroBytes(&content[i * kCols], kCols * sizeof(content[0])); - }); - - CompressScaled(content.get(), mat->NumElements(), ws, *mat, pool); - mat->set_scale(1.2f); // Arbitrary value, different from 1. - return mat; -} - // Returns 1-norm, used for estimating tolerable numerical differences. -double MaxColAbsSum(const float* HWY_RESTRICT a, size_t rows, size_t cols) { +double MaxColAbsSum(const float* HWY_RESTRICT a, const Extents2D& extents) { double max_col_abs_sum = 0.0; - for (size_t c = 0; c < cols; c++) { + for (size_t c = 0; c < extents.cols; c++) { double col_abs_sum = 0.0; - for (size_t r = 0; r < rows; r++) { - col_abs_sum += hwy::ScalarAbs(a[r * cols + c]); + for (size_t r = 0; r < extents.rows; r++) { + col_abs_sum += hwy::ScalarAbs(a[r * extents.cols + c]); } max_col_abs_sum = HWY_MAX(max_col_abs_sum, col_abs_sum); } return max_col_abs_sum; } +// B is already transposed. 
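// AssertClose below decompresses both inputs and accepts a per-element error
// of up to 200 * ||A||_1 * ||B_trans||_1 * eps, where ||.||_1 is the maximum
// column absolute sum from MaxColAbsSum and eps is the epsilon of the
// reference type (BF16 unless both inputs are f32).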
template -void AssertClose(size_t rows_ac, size_t cols_ab, size_t cols_c_rows_b, - const MatTA* HWY_RESTRICT pa, - const MatTB* HWY_RESTRICT pb_trans, - const float* HWY_RESTRICT expected_c, - const float* HWY_RESTRICT actual_c) { +void AssertClose(const ConstMat& A, const ConstMat& B, + const RowPtrF& C_slow, const RowPtrF& C) { const hn::ScalableTag df; - const size_t num_a = rows_ac * cols_ab; - const size_t num_b = cols_c_rows_b * cols_ab; + const size_t num_a = A.extents.Area(); + const size_t num_b = B.extents.Area(); HWY_ASSERT(num_a % hn::Lanes(df) == 0); // for DecompressAndZeroPad HWY_ASSERT(num_b % hn::Lanes(df) == 0); // for DecompressAndZeroPad - const size_t num_c = rows_ac * cols_c_rows_b; FloatPtr a = hwy::AllocateAligned(num_a); FloatPtr b_trans = hwy::AllocateAligned(num_b); HWY_ASSERT(a && b_trans); - DecompressAndZeroPad(df, MakeSpan(pa, num_a), 0, a.get(), num_a); - DecompressAndZeroPad(df, MakeSpan(pb_trans, num_b), 0, b_trans.get(), num_b); + HWY_ASSERT(A.ofs == 0 && B.ofs == 0); + DecompressAndZeroPad(df, MakeSpan(A.ptr, num_a), 0, a.get(), num_a); + DecompressAndZeroPad(df, MakeSpan(B.ptr, num_b), 0, b_trans.get(), num_b); - const double norm = MaxColAbsSum(a.get(), rows_ac, cols_ab) * - MaxColAbsSum(b_trans.get(), cols_c_rows_b, cols_ab); + const double norm = MaxColAbsSum(a.get(), A.Extents()) * + MaxColAbsSum(b_trans.get(), B.Extents()); // Dot(float,BF16) rounds both to BF16. using RefType = hwy::If() && IsF32(), float, BF16>; const double epsilon = hwy::ConvertScalarTo(hwy::Epsilon()); const double tolerance = 200.0 * norm * epsilon; - for (size_t idx = 0; idx < num_c; idx++) { - const double expected_value = expected_c[idx]; - const double actual_value = actual_c[idx]; + for (size_t r = 0; r < A.extents.rows; r++) { + const float* expected_row = C_slow.Row(r); + const float* actual_row = C.Row(r); + for (size_t c = 0; c < B.extents.rows; c++) { + const double expected_value = static_cast(expected_row[c]); + const double actual_value = static_cast(actual_row[c]); - if (!(expected_value - tolerance <= actual_value && - actual_value <= expected_value + tolerance)) { - fprintf( - stderr, - "expected[%lu]: %f, actual[%lu]: %f, norm %f eps %E tolerance %f\n", - idx, expected_value, idx, actual_value, norm, epsilon, tolerance); - HWY_ASSERT(0); + if (!(expected_value - tolerance <= actual_value && + actual_value <= expected_value + tolerance)) { + fprintf( + stderr, + "(%zu,%zu): expected %f, actual %f, norm %f eps %E tolerance %f\n", + r, c, expected_value, actual_value, norm, epsilon, tolerance); + } } } } +// B is already transposed. template -HWY_INLINE void MatMulSlow(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, - const MatTA* HWY_RESTRICT a, - const MatTB* HWY_RESTRICT b_trans, const float scale, +HWY_INLINE void MatMulSlow(const ConstMat A, const ConstMat B, const float* HWY_RESTRICT add_row, MatMulEnv& env, - float* HWY_RESTRICT out) { + const RowPtrF& C) { // MatTA can be any Packed except NuqStream because it uses pointer // arithmetic, because it is the second argument to Dot, which does not // support a v_ofs. 
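// Reference formula: C(r,c) = add(c) + scale * sum_k A(r,k) * B(c,k), with B
// stored transposed and scale = A.scale * B.scale.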
static_assert(sizeof(MatTA) >= sizeof(BF16), "A matrix must be BF16/f32"); + const float scale = A.scale * B.scale; const hn::ScalableTag df; // lane type is ignored const PackedSpan b_span = - MakeSpan(b_trans, cols_a_rows_b * cols_bc); + MakeSpan(B.ptr, B.ofs + B.extents.Area()); + const Extents2D C_extents(A.extents.rows, C.Cols()); StaticPartitionRowsAndCols( - env.Pools(), rows_ac, cols_bc, sizeof(MatTB), - [&](size_t /*node*/, hwy::ThreadPool& pool, - const size_t /*worker_offset*/, const size_t row_begin, - const size_t row_end, const size_t col_begin, const size_t col_end) { - pool.Run(row_begin, row_end, - [&](const uint64_t row, size_t /*thread*/) { - for (size_t col = col_begin; col < col_end; ++col) { - const float add = add_row ? add_row[col] : 0.0f; - out[row * cols_bc + col] = - scale * Dot(df, b_span, col * cols_a_rows_b, - a + row * cols_a_rows_b, cols_a_rows_b) + - add; - } - }); + env.Pools(), C_extents, sizeof(MatTB), + [&](const Range2D& C_range, const TaskLocation& loc) { + loc.cluster.Run( + C_range.rows.begin(), C_range.rows.end(), + [&](const uint64_t row, size_t /*thread*/) { + float* HWY_RESTRICT C_row = C.Row(row); + for (size_t row_b_col_c : C_range.cols) { + const float add = add_row ? add_row[row_b_col_c] : 0.0f; + C_row[row_b_col_c] = + add + scale * Dot(df, b_span, row_b_col_c * B.extents.cols, + A.ptr + A.Row(row), A.extents.cols); + } + }); }); } -void PrintSpeed(const char* algo, size_t rows_ac, size_t cols_a_rows_b, - size_t cols_bc, double elapsed) { - const size_t num_b = cols_a_rows_b * cols_bc; +void PrintSpeed(const char* algo, const Extents2D& A_extents, + const Extents2D& B_extents, double elapsed) { + const size_t num_b = B_extents.Area(); // 2x because of FMA. fprintf(stderr, " %10s: %f seconds, %.1f GFLOPS.\n", algo, - elapsed, 2 * 1E-9 * rows_ac * num_b / elapsed); + elapsed, 2 * 1E-9 * A_extents.rows * num_b / elapsed); } -template -void TestMatMul(MatMulEnv& env) { +template +void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add, + MatMulEnv& env) { hwy::ThreadPool& pool = env.Pool(); - const bool want_bench = kColsBC > 2000; // avoid spam for small matrices + const bool want_bench = cols_bc > 2000; // avoid spam for small matrices fprintf(stderr, "TestMatMul %lu, %lu, %lu, add=%d, MatTA=%s, MatTB=%s\n", - kRowsAC, kColsARowsB, kColsBC, kAdd, TypeName(), + rows_ac, cols_a_rows_b, cols_bc, add, TypeName(), TypeName()); - std::unique_ptr> a = - GenerateMat(0, pool); - std::unique_ptr> b_trans = - GenerateTransposedMat(0, pool); - FloatPtr c = hwy::AllocateAligned(kRowsAC * kColsBC); - HWY_ASSERT(c); + const Extents2D A_extents(rows_ac, cols_a_rows_b); + const Extents2D B_extents(cols_bc, cols_a_rows_b); // already transposed + const Extents2D C_extents(rows_ac, cols_bc); - const float scale = a->scale() * b_trans->scale(); - std::unique_ptr> add; - if (kAdd) { - add = GenerateMat(0, pool); - add->set_scale(1.0f); + MatStoragePtr a = GenerateMat(A_extents, pool); + MatStoragePtr b_trans = GenerateTransposedMat(B_extents, pool); + RowVectorBatch c_slow_batch(C_extents); + RowVectorBatch c_batch(C_extents); + HWY_ASSERT(a && b_trans); + + std::unique_ptr> add_storage; + if (add) { + add_storage = GenerateMat(Extents2D(1, cols_bc), pool); + HWY_ASSERT(add_storage); + add_storage->set_scale(1.0f); } - std::unique_ptr> c_slow = - GenerateZeroMat(pool); + const auto A = ConstMatFromWeights(*a); + const auto B = ConstMatFromWeights(*b_trans); + const float* add_row = add ? 
add_storage->data_scale1() : nullptr; + const RowPtrF C_slow = RowPtrFromBatch(c_slow_batch); + const RowPtrF C = RowPtrFromBatch(c_batch); + const double start_slow = hwy::platform::Now(); - MatMulSlow(kRowsAC, kColsARowsB, kColsBC, a->data(), b_trans->data(), scale, - kAdd ? add->data() : nullptr, env, c_slow->data()); + MatMulSlow(A, B, add_row, env, C_slow); if (want_bench) { - PrintSpeed("MatMulSlow", kRowsAC, kColsARowsB, kColsBC, + PrintSpeed("MatMulSlow", A_extents, B_extents, hwy::platform::Now() - start_slow); } double min_elapsed = hwy::HighestValue(); for (int rep = 0; rep < (want_bench ? 3 : 1); ++rep) { const double start_tiled = hwy::platform::Now(); - MatMul(kRowsAC, ConstMat(a->data(), kColsARowsB), - ConstMat(b_trans->data(), kColsARowsB), scale, - kAdd ? add->data_scale1() : nullptr, env, - MutableMat(c.get(), kColsBC)); + MatMul(A, B, add_row, env, C); min_elapsed = HWY_MIN(min_elapsed, hwy::platform::Now() - start_tiled); } if (want_bench) { - PrintSpeed("MatMul", kRowsAC, kColsARowsB, kColsBC, min_elapsed); + PrintSpeed("MatMul", A_extents, B_extents, min_elapsed); } - AssertClose(kRowsAC, kColsARowsB, kColsBC, a->data(), b_trans->data(), - c_slow->data(), c.get()); + AssertClose(A, B, C_slow, C); } void TestAllMatMul() { @@ -264,8 +257,9 @@ void TestAllMatMul() { return; } - NestedPools pools(4, /*pin=*/1); - pools.StartSpinning(); + NestedPools pools(4, /*pin=*/Tristate::kDefault); + Tristate use_spinning = Tristate::kDefault; + pools.MaybeStartSpinning(use_spinning); Allocator::Init(pools.Topology()); MatMulEnv env(pools); @@ -273,52 +267,54 @@ void TestAllMatMul() { using SFP = SfpStream; // large-scale test: batch_size=128 is better than 64 or 256 for SKX. - TestMatMul<128, 24576, 3072, /*kAdd=*/false, F32, SFP>(env); - TestMatMul<128, 3072, 24576, /*kAdd=*/false, F32, SFP>(env); - TestMatMul<1, 24576, 3072, /*kAdd=*/false, F32, F32>(env); - TestMatMul<1, 3072, 24576, /*kAdd=*/false, F32, F32>(env); + // TestMatMul(128, 24576, 3072, /*add=*/false, env); + // TestMatMul(128, 3072, 24576, /*add=*/false, env); + TestMatMul(1, 24576, 3072, /*add=*/false, env); + TestMatMul(1, 3072, 24576, /*add=*/false, env); + TestMatMul(1, 24576, 3072, /*add=*/false, env); + TestMatMul(1, 3072, 24576, /*add=*/false, env); // medium-sized square test - temporarily disabled for faster testing. if constexpr (false) { - TestMatMul<512, 512, 512, /*kAdd=*/false, F32>(env); - TestMatMul<512, 512, 512, /*kAdd=*/true, BF16>(env); - TestMatMul<512, 512, 512, /*kAdd=*/false, F32, BF16>(env); - TestMatMul<512, 512, 512, /*kAdd=*/true, BF16, F32>(env); - TestMatMul<512, 512, 512, /*kAdd=*/false, F32, SFP>(env); - TestMatMul<512, 512, 512, /*kAdd=*/true, BF16, SFP>(env); + TestMatMul(512, 512, 512, /*add=*/false, env); + TestMatMul(512, 512, 512, /*add=*/true, env); + TestMatMul(512, 512, 512, /*add=*/false, env); + TestMatMul(512, 512, 512, /*add=*/true, env); + TestMatMul(512, 512, 512, /*add=*/false, env); + TestMatMul(512, 512, 512, /*add=*/true, env); } // minimal non-square test. kColsARowsB must be at least 2 vectors. 
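  // The row counts below (1..4 and 29..35) exercise every num_rows remainder
  // of the 4-row register tiling (kRegRows == 4) in MatMulImpl.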
- TestMatMul<35, 128, 32, /*kAdd=*/false, F32>(env); - TestMatMul<34, 128, 32, /*kAdd=*/true, BF16>(env); - TestMatMul<33, 128, 32, /*kAdd=*/false, F32, BF16>(env); - TestMatMul<33, 128, 32, /*kAdd=*/true, BF16, F32>(env); - TestMatMul<31, 128, 32, /*kAdd=*/false, F32, SFP>(env); - TestMatMul<29, 128, 32, /*kAdd=*/true, BF16, SFP>(env); - TestMatMul<4, 128, 32, /*kAdd=*/true, F32>(env); - TestMatMul<4, 128, 32, /*kAdd=*/false, BF16>(env); - TestMatMul<4, 128, 32, /*kAdd=*/true, F32, BF16>(env); - TestMatMul<4, 128, 32, /*kAdd=*/false, BF16, F32>(env); - TestMatMul<4, 128, 32, /*kAdd=*/true, F32, SFP>(env); - TestMatMul<4, 128, 32, /*kAdd=*/false, BF16, SFP>(env); - TestMatMul<3, 128, 32, /*kAdd=*/false, F32>(env); - TestMatMul<3, 128, 32, /*kAdd=*/true, BF16>(env); - TestMatMul<3, 128, 32, /*kAdd=*/false, F32, BF16>(env); - TestMatMul<3, 128, 32, /*kAdd=*/true, BF16, F32>(env); - TestMatMul<3, 128, 32, /*kAdd=*/false, F32, SFP>(env); - TestMatMul<3, 128, 32, /*kAdd=*/true, BF16, SFP>(env); - TestMatMul<2, 128, 64, /*kAdd=*/true, F32>(env); - TestMatMul<2, 128, 64, /*kAdd=*/false, BF16>(env); - TestMatMul<2, 128, 64, /*kAdd=*/true, F32, BF16>(env); - TestMatMul<2, 128, 64, /*kAdd=*/false, BF16, F32>(env); - TestMatMul<2, 128, 64, /*kAdd=*/true, F32, SFP>(env); - TestMatMul<2, 128, 64, /*kAdd=*/false, BF16, SFP>(env); - TestMatMul<1, 128, 32, /*kAdd=*/false, F32>(env); - TestMatMul<1, 128, 32, /*kAdd=*/true, BF16>(env); - TestMatMul<1, 128, 32, /*kAdd=*/false, F32, BF16>(env); - TestMatMul<1, 128, 32, /*kAdd=*/true, BF16, F32>(env); - TestMatMul<1, 128, 32, /*kAdd=*/false, F32, SFP>(env); - TestMatMul<1, 128, 32, /*kAdd=*/true, BF16, SFP>(env); + TestMatMul(35, 128, 32, /*add=*/false, env); + TestMatMul(34, 128, 32, /*add=*/true, env); + TestMatMul(33, 128, 32, /*add=*/false, env); + TestMatMul(33, 128, 32, /*add=*/true, env); + TestMatMul(31, 128, 32, /*add=*/false, env); + TestMatMul(29, 128, 32, /*add=*/true, env); + TestMatMul(4, 128, 32, /*add=*/true, env); + TestMatMul(4, 128, 32, /*add=*/false, env); + TestMatMul(4, 128, 32, /*add=*/true, env); + TestMatMul(4, 128, 32, /*add=*/false, env); + TestMatMul(4, 128, 32, /*add=*/true, env); + TestMatMul(4, 128, 32, /*add=*/false, env); + TestMatMul(3, 128, 32, /*add=*/false, env); + TestMatMul(3, 128, 32, /*add=*/true, env); + TestMatMul(3, 128, 32, /*add=*/false, env); + TestMatMul(3, 128, 32, /*add=*/true, env); + TestMatMul(3, 128, 32, /*add=*/false, env); + TestMatMul(3, 128, 32, /*add=*/true, env); + TestMatMul(2, 128, 64, /*add=*/true, env); + TestMatMul(2, 128, 64, /*add=*/false, env); + TestMatMul(2, 128, 64, /*add=*/true, env); + TestMatMul(2, 128, 64, /*add=*/false, env); + TestMatMul(2, 128, 64, /*add=*/true, env); + TestMatMul(2, 128, 64, /*add=*/false, env); + TestMatMul(1, 128, 32, /*add=*/false, env); + TestMatMul(1, 128, 32, /*add=*/true, env); + TestMatMul(1, 128, 32, /*add=*/false, env); + TestMatMul(1, 128, 32, /*add=*/true, env); + TestMatMul(1, 128, 32, /*add=*/false, env); + TestMatMul(1, 128, 32, /*add=*/true, env); } // NOLINTNEXTLINE(google-readability-namespace-comments) diff --git a/ops/ops_test.cc b/ops/ops_test.cc index a6f9b2d..93b8f31 100644 --- a/ops/ops_test.cc +++ b/ops/ops_test.cc @@ -389,7 +389,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy( void TestRopeAndMulBy() { ModelConfig config = ConfigFromModel(Model::GEMMA2_9B); int dim_qkv = config.layer_configs[0].qkv_dim; - RowVectorBatch x(1, dim_qkv); + RowVectorBatch x(Extents2D(1, dim_qkv)); std::mt19937 gen; gen.seed(0x12345678); diff 
--git a/paligemma/paligemma_test.cc b/paligemma/paligemma_test.cc index a320cda..b820eec 100644 --- a/paligemma/paligemma_test.cc +++ b/paligemma/paligemma_test.cc @@ -14,7 +14,6 @@ // limitations under the License. #include -#include #include #include @@ -45,20 +44,20 @@ class PaliGemmaTest : public ::testing::Test { std::string GemmaReply(const std::string& prompt_text) const; void TestQuestions(const char* kQA[][2], size_t num_questions); - std::unique_ptr image_tokens_; + ImageTokens image_tokens_; }; void PaliGemmaTest::InitVit(const std::string& path) { ASSERT_NE(s_env->GetModel(), nullptr); Gemma& model = *(s_env->GetModel()); - image_tokens_ = std::make_unique( - model.GetModelConfig().vit_seq_len, model.GetModelConfig().model_dim); + image_tokens_ = ImageTokens(Extents2D(model.GetModelConfig().vit_seq_len, + model.GetModelConfig().model_dim)); Image image; HWY_ASSERT(model.Info().training == ModelTraining::PALIGEMMA); HWY_ASSERT(image.ReadPPM(path)); image.Resize(); RuntimeConfig runtime_config = {.verbosity = 0, .gen = &s_env->MutableGen()}; - model.GenerateImageTokens(runtime_config, image, *image_tokens_); + model.GenerateImageTokens(runtime_config, image, image_tokens_); } std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{ @@ -67,7 +66,7 @@ std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{ RuntimeConfig runtime_config = {.max_generated_tokens = 512, .verbosity = 0, .gen = &s_env->MutableGen()}; - runtime_config.image_tokens = image_tokens_.get(); + runtime_config.image_tokens = &image_tokens_; size_t abs_pos = 0; std::string mutable_prompt = prompt_text; std::vector tokens = s_env->WrapAndTokenize(mutable_prompt); @@ -79,7 +78,7 @@ std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{ return true; }; runtime_config.stream_token = stream_token, - tokens.insert(tokens.begin(), image_tokens_->BatchSize(), 0); + tokens.insert(tokens.begin(), image_tokens_.BatchSize(), 0); size_t num_tokens = tokens.size(); size_t prefix_end = num_tokens; runtime_config.prefill_tbatch_size = num_tokens; diff --git a/util/allocator.cc b/util/allocator.cc index a7d2352..dd9943b 100644 --- a/util/allocator.cc +++ b/util/allocator.cc @@ -162,20 +162,19 @@ static void BindMemory(void* ptr, size_t bytes, size_t node) { static void BindMemory(void*, size_t, size_t) {} #endif // GEMMA_NUMA && HWY_OS_LINUX -void BindTensor(NestedPools& nested, size_t rows, size_t cols, +void BindTensor(NestedPools& nested, const Extents2D& extents, size_t bytes_per_col, void* ptr) { if (!Allocator::UseNUMA()) return; uint8_t* p8 = static_cast(ptr); - const size_t bytes_per_row = cols * bytes_per_col; + const size_t bytes_per_row = extents.cols * bytes_per_col; StaticPartitionRowsAndCols( - nested, rows, cols, bytes_per_col, - [&](size_t node, hwy::ThreadPool&, const size_t /*worker_offset*/, - const size_t row_begin, const size_t row_end, const size_t col_begin, - const size_t col_end) { - for (size_t row = row_begin; row < row_end; ++row) { - uint8_t* slice = p8 + row * bytes_per_row + col_begin * bytes_per_col; - const size_t slice_size = (col_end - col_begin) * bytes_per_col; - BindMemory(slice, slice_size, node); + nested, extents, bytes_per_col, + [&](const Range2D& r, const TaskLocation& loc) { + for (size_t row : r.rows) { + uint8_t* slice = + p8 + row * bytes_per_row + r.cols.begin() * bytes_per_col; + const size_t slice_size = r.cols.Num() * bytes_per_col; + BindMemory(slice, slice_size, loc.node); } }); } diff --git a/util/allocator.h 
b/util/allocator.h index ca1df45..08476e3 100644 --- a/util/allocator.h +++ b/util/allocator.h @@ -22,6 +22,7 @@ #include // std::aligned_alloc / _aligned_malloc // IWYU pragma: begin_exports +#include "util/basics.h" #include "util/threading.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" @@ -52,49 +53,6 @@ ByteStorageT AllocateSizeof() { return hwy::AllocateAligned(sizeof(T)); } -// Owns dynamically-allocated aligned memory for a batch of row vectors. -// This can be seen as a (batch_size x len) matrix. -template -class RowVectorBatch { - public: - // Default ctor for Activations ctor. - RowVectorBatch() : batch_size_(0), len_(0) {} - // Main ctor, called from Activations::Allocate. - RowVectorBatch(size_t batch_size, size_t len) - : batch_size_(batch_size), len_(len) { - mem_ = hwy::AllocateAligned(batch_size * len); - } - - // Move-only - RowVectorBatch(RowVectorBatch&) noexcept = delete; - RowVectorBatch& operator=(RowVectorBatch&) noexcept = delete; - RowVectorBatch(RowVectorBatch&&) noexcept = default; - RowVectorBatch& operator=(RowVectorBatch&&) noexcept = default; - - size_t BatchSize() const { return batch_size_; } - size_t Len() const { return len_; } - - // Returns the given row vector of length `Len()`. - T* Batch(size_t batch_idx) { - HWY_DASSERT(batch_idx < batch_size_); - return mem_.get() + batch_idx * len_; - } - const T* Batch(size_t batch_idx) const { - HWY_DASSERT(batch_idx < batch_size_); - return mem_.get() + batch_idx * len_; - } - - // For MatMul or other operations that process the entire batch at once. - T* All() { return mem_.get(); } - const T* Const() const { return mem_.get(); } - size_t NumBytes() const { return batch_size_ * len_ * sizeof(T); } - - private: - hwy::AlignedFreeUniquePtr mem_; - size_t batch_size_; // rows in the matrix - size_t len_; // columns in the matrix = vector length -}; - // Stateful in order to know whether to bind to NUMA nodes. `Monostate` for // convenience - avoids passing around a reference. class Allocator { @@ -167,10 +125,24 @@ class Allocator { static size_t alignment_; }; +// For shorter arguments to the StaticPartitionRowsAndCols functor. +struct TaskLocation { + TaskLocation(size_t node, size_t package_idx, hwy::ThreadPool& cluster, + size_t worker_offset) + : node(node), + package_idx(package_idx), + cluster(cluster), + worker_offset(worker_offset) {} + size_t node; + size_t package_idx; + hwy::ThreadPool& cluster; + const size_t worker_offset; +}; + // Used in MatMul and allocator.h. Defined here because it depends on // Allocator::Alignment(). template -void StaticPartitionRowsAndCols(NestedPools& nested, size_t rows, size_t cols, +void StaticPartitionRowsAndCols(NestedPools& nested, Extents2D extents, size_t bytes_per_element, const Func& func) { // Both rows and cols must be a multiple of the alignment to avoid // touching remote pages. 
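To illustrate the new callback signature, a minimal sketch mirroring the
BindTensor and MatMulSlow call sites in this patch; `nested` and `extents`
stand for whatever the caller has in scope:

  StaticPartitionRowsAndCols(
      nested, extents, /*bytes_per_element=*/sizeof(float),
      [&](const Range2D& r, const TaskLocation& loc) {
        // Each invocation receives one rectangular sub-range plus where to run.
        loc.cluster.Run(r.rows.begin(), r.rows.end(),
                        [&](uint64_t row, size_t /*thread*/) {
                          for (size_t col : r.cols) {
                            // Touch element (row, col); after BindTensor, this
                            // range resides on NUMA node loc.node.
                          }
                        });
      });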
@@ -183,14 +155,15 @@ void StaticPartitionRowsAndCols(NestedPools& nested, size_t rows, size_t cols, hwy::ThreadPool& all_packages = nested.AllPackages(); const size_t num_packages = all_packages.NumWorkers(); const size_t cols_per_package = - hwy::RoundUpTo(hwy::DivCeil(cols, num_packages), multiple); - const size_t col_tasks = hwy::DivCeil(cols, cols_per_package); + hwy::RoundUpTo(hwy::DivCeil(extents.cols, num_packages), multiple); + const size_t col_tasks = hwy::DivCeil(extents.cols, cols_per_package); HWY_ASSERT(col_tasks <= num_packages); all_packages.Run( 0, col_tasks, [&](uint64_t package_idx, size_t package_thread) { HWY_ASSERT(package_idx == package_thread); // one task per worker const size_t col_begin = package_idx * cols_per_package; - const size_t col_end = HWY_MIN(col_begin + cols_per_package, cols); + const Range1D col_range = + MakeRange1D(col_begin, extents.cols, cols_per_package); // Static partitioning of rows across the package's clusters. We assume // that row sharding is cheaper. In MatMul, results can indeed be @@ -198,8 +171,8 @@ void StaticPartitionRowsAndCols(NestedPools& nested, size_t rows, size_t cols, hwy::ThreadPool& all_clusters = nested.AllClusters(package_idx); const size_t num_clusters = all_clusters.NumWorkers(); const size_t rows_per_cluster = - hwy::RoundUpTo(hwy::DivCeil(rows, num_clusters), multiple); - const size_t row_tasks = hwy::DivCeil(rows, rows_per_cluster); + hwy::RoundUpTo(hwy::DivCeil(extents.rows, num_clusters), multiple); + const size_t row_tasks = hwy::DivCeil(extents.rows, rows_per_cluster); HWY_ASSERT(row_tasks <= num_clusters); all_clusters.Run( 0, row_tasks, [&](uint64_t cluster_idx, size_t cluster_thread) { @@ -217,11 +190,11 @@ void StaticPartitionRowsAndCols(NestedPools& nested, size_t rows, size_t cols, nested.WorkerOffset(package_idx, cluster_idx); const size_t row_begin = cluster_idx * rows_per_cluster; - const size_t row_end = - HWY_MIN(row_begin + rows_per_cluster, rows); + const Range1D row_range = + MakeRange1D(row_begin, extents.rows, rows_per_cluster); - func(node, cluster, worker_offset, row_begin, row_end, col_begin, - col_end); + func(Range2D(row_range, col_range), + TaskLocation(node, package_idx, cluster, worker_offset)); }); }); } diff --git a/util/app.h b/util/app.h index bf7dc27..ebc16b9 100644 --- a/util/app.h +++ b/util/app.h @@ -28,6 +28,7 @@ #include "gemma/common.h" #include "gemma/gemma.h" // For CreateGemma #include "util/args.h" +#include "util/basics.h" // Tristate #include "util/threading.h" #include "hwy/base.h" // HWY_IS_ASAN @@ -59,7 +60,9 @@ class AppArgs : public ArgsBase { int verbosity; size_t max_threads; // divided among the detected clusters - int pin; // -1 = auto, 0 = no, 1 = yes + Tristate pin; // pin threads? + Tristate spin; // use spin waits? + // For BoundedSlice: size_t skip_packages; size_t max_packages; @@ -81,7 +84,10 @@ class AppArgs : public ArgsBase { // The exact meaning is more subtle: see the comment at NestedPools ctor. visitor(max_threads, "num_threads", size_t{0}, "Maximum number of threads to use; default 0 = unlimited.", 2); - visitor(pin, "pin", -1, "Pin threads? -1 = auto, 0 = no, 1 = yes.", 2); + visitor(pin, "pin", Tristate::kDefault, + "Pin threads? -1 = auto, 0 = no, 1 = yes.", 2); + visitor(spin, "spin", Tristate::kDefault, + "Use spin waits? -1 = auto, 0 = no, 1 = yes.", 2); // These can be used to partition CPU sockets/packages and their // clusters/CCXs across several program instances. The default is to use // all available resources. 
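As a usage sketch of the new Tristate knobs, mirroring the test changes above;
the exact spinning policy behind kDefault is decided by the pools and is an
assumption here:

  NestedPools pools(/*max_threads=*/0, /*pin=*/Tristate::kDefault);
  Tristate use_spinning = Tristate::kDefault;
  pools.MaybeStartSpinning(use_spinning);
  // On the command line, --pin and --spin accept true/on/1, false/off/0 and
  // default/auto/-1, as parsed in util/args.h below.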
diff --git a/util/args.h b/util/args.h index 98c8e5c..ab496ae 100644 --- a/util/args.h +++ b/util/args.h @@ -24,6 +24,7 @@ #include #include "compression/io.h" +#include "util/basics.h" // Tristate #include "hwy/base.h" // HWY_ABORT namespace gcpp { @@ -62,6 +63,13 @@ class ArgsBase { } } + void operator()(const Tristate& t, const char* name, + const Tristate& /*init*/, const char* /*help*/, + int print_verbosity = 0) const { + if (verbosity_ >= print_verbosity) { + fprintf(stderr, "%-30s: %s\n", name, ToString(t)); + } + } void operator()(const std::string& t, const char* name, const std::string& /*init*/, const char* /*help*/, int print_verbosity = 0) const { @@ -127,13 +135,33 @@ class ArgsBase { return true; } - static bool SetValue(const char* string, bool& t) { + // Returns lower-cased string. Arg names are expected to be ASCII-only. + static std::string ToLower(const char* string) { std::string value(string); - // Lower-case. Arg names are expected to be ASCII-only. std::transform(value.begin(), value.end(), value.begin(), [](char c) { return 'A' <= c && c <= 'Z' ? c - ('Z' - 'z') : c; }); + return value; + } + static bool SetValue(const char* string, Tristate& t) { + const std::string value = ToLower(string); + if (value == "true" || value == "on" || value == "1") { + t = Tristate::kTrue; + return true; + } else if (value == "false" || value == "off" || value == "0") { + t = Tristate::kFalse; + return true; + } else if (value == "default" || value == "auto" || value == "-1") { + t = Tristate::kDefault; + return true; + } else { + return false; + } + } + + static bool SetValue(const char* string, bool& t) { + const std::string value = ToLower(string); if (value == "true" || value == "on" || value == "1") { t = true; return true; diff --git a/util/basics.h b/util/basics.h index 3aee649..cfe2204 100644 --- a/util/basics.h +++ b/util/basics.h @@ -20,7 +20,8 @@ #include #include -#include "hwy/base.h" +#include "hwy/aligned_allocator.h" +#include "hwy/base.h" // HWY_IS_MSAN // IWYU pragma: end_exports #if HWY_IS_MSAN @@ -29,6 +30,19 @@ namespace gcpp { +enum class Tristate : int32_t { kFalse = 0, kTrue = 1, kDefault = -1 }; + +static inline const char* ToString(Tristate t) { + switch (t) { + case Tristate::kFalse: + return "false"; + case Tristate::kTrue: + return "true"; + case Tristate::kDefault: + return "default"; + } +} + using BF16 = hwy::bfloat16_t; static inline void MaybeCheckInitialized(const void* ptr, size_t size) { @@ -46,6 +60,195 @@ struct TokenAndProb { float prob; }; +// Entire size of a 2D array. By contrast, Range2D is a subrange. +struct Extents2D { + Extents2D() : rows(0), cols(0) {} + Extents2D(size_t rows, size_t cols) : rows(rows), cols(cols) { + HWY_DASSERT(rows != 0); + HWY_DASSERT(cols != 0); + } + + size_t Area() const { return rows * cols; } + + size_t rows; + size_t cols; +}; + +// Range2D consists of two Range1D. +struct Range1D { + Range1D(size_t begin, size_t end) : begin_(begin), end_(end) { + HWY_DASSERT(begin < end); + } + size_t Num() const { return end_ - begin_; } + + // Enable range-based for loops. + class Iterator { + public: + Iterator(size_t i) : i_(i) {} + + Iterator& operator++() { + ++i_; + return *this; + } + bool operator!=(const Iterator& other) const { return i_ != other.i_; } + size_t operator*() const { return i_; } + // Enable using begin() directly as a size_t. 
+ operator size_t() const { return i_; } + + private: + size_t i_; + }; + Iterator begin() const { return Iterator(begin_); } + Iterator end() const { return Iterator(end_); } + + const size_t begin_; + const size_t end_; +}; + +static inline Range1D MakeRange1D(size_t begin, size_t end, size_t max_size) { + return Range1D(begin, HWY_MIN(begin + max_size, end)); +} + +// In MatMul, the two axes are used independently, hence we do not define +// Range2D as a top-left and extents. +struct Range2D { + Range2D(Range1D rows, Range1D cols) : rows(rows), cols(cols) {} + const Range1D rows; + const Range1D cols; +}; + +// Lightweight version of `MatPtr` used for the C argument of `MatMul`, because +// it is always float and does not support compressed T, but does support an +// arbitrary stride >= cols. +template +class RowPtr { + public: + RowPtr(T* HWY_RESTRICT row0, size_t cols) + : row0_(row0), cols_(cols), stride_(cols) {} + + T* HWY_RESTRICT Row(size_t r) const { return row0_ + stride_ * r; } + size_t Cols() const { return cols_; } + + size_t Stride() const { return stride_; } + void SetStride(size_t stride) { + HWY_DASSERT(stride >= Cols()); + stride_ = stride; + } + + private: + T* HWY_RESTRICT row0_; + size_t stride_; + size_t cols_; +}; + +using RowPtrF = RowPtr; + +// Owns dynamically-allocated aligned memory for a batch of row vectors. +// This can be seen as a (batch_size x cols) matrix. Unlike `RowPtr`, this owns +// the memory. +template +class RowVectorBatch { + public: + // Default ctor for Activations ctor. + RowVectorBatch() = default; + // Main ctor, called from Activations::Allocate. + RowVectorBatch(Extents2D extents) : extents_(extents) { + mem_ = hwy::AllocateAligned(extents_.rows * extents_.cols); + } + + // Move-only + RowVectorBatch(RowVectorBatch&) noexcept = delete; + RowVectorBatch& operator=(RowVectorBatch&) noexcept = delete; + RowVectorBatch(RowVectorBatch&&) noexcept = default; + RowVectorBatch& operator=(RowVectorBatch&&) noexcept = default; + + size_t BatchSize() const { return extents_.rows; } + size_t Cols() const { return extents_.cols; } + Extents2D Extents() const { return extents_; } + + // Returns the given row vector of length `Cols()`. + T* Batch(size_t batch_idx) { + HWY_DASSERT(batch_idx < BatchSize()); + return mem_.get() + batch_idx * Cols(); + } + const T* Batch(size_t batch_idx) const { + HWY_DASSERT(batch_idx < BatchSize()); + return mem_.get() + batch_idx * Cols(); + } + + // For MatMul or other operations that process the entire batch at once. + // TODO: remove once we only use Mat. + T* All() { return mem_.get(); } + const T* Const() const { return mem_.get(); } + size_t NumBytes() const { return BatchSize() * Cols() * sizeof(T); } + + private: + hwy::AlignedFreeUniquePtr mem_; + Extents2D extents_; +}; + +// Used for the A and B arguments of `MatMul`, which are always const. +// Create via MakeConstMat. This differs from `RowPtr` in that it supports the +// `ofs` required for compressed T. +template +struct ConstMat { + ConstMat(const T* ptr, Extents2D extents, size_t ofs = 0) + : ptr(ptr), extents(extents), ofs(ofs) { + HWY_DASSERT(ptr != nullptr); + } + // TODO: support stride for page alignment. + size_t Row(size_t r) const { + if constexpr (HWY_IS_DEBUG_BUILD) { + if (r >= extents.rows) { + HWY_ABORT("ConstMat::Row %zu out of bounds %zu", r, extents.rows); + } + } + return ofs + extents.cols * r; + } + + const Extents2D& Extents() const { return extents; } + + // Shrinks the row-extent of this matrix view, i.e. 
reduces the view to a + // subrange of the original rows starting at row 0. + void ShrinkRows(size_t rows) { + HWY_ASSERT(rows <= extents.rows); + extents.rows = rows; + } + + const T* HWY_RESTRICT ptr; + Extents2D extents; + + // `scale` allows expanding the smaller range of `SfpStream` to the original + // values. MatFromWeights sets this from `MatPtr`. + float scale = 1.0f; + + // Offset to add to `ptr`; separate because T=NuqStream does not support + // pointer arithmetic. + size_t ofs; +}; + +// For deducing T. +template +ConstMat MakeConstMat(T* HWY_RESTRICT ptr, Extents2D extents, + size_t ofs = 0) { + return ConstMat(ptr, extents, ofs); +} + +// For A argument to MatMul (activations). +template +ConstMat ConstMatFromBatch(size_t batch_size, + const RowVectorBatch& row_vectors) { + HWY_DASSERT(batch_size <= row_vectors.BatchSize()); + return MakeConstMat(const_cast(row_vectors.Const()), + Extents2D(batch_size, row_vectors.Cols())); +} + +// For C argument to MatMul. +template +RowPtr RowPtrFromBatch(RowVectorBatch& row_vectors) { + return RowPtr(row_vectors.All(), row_vectors.Cols()); +} + } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_UTIL_BASICS_H_ diff --git a/util/threading.cc b/util/threading.cc new file mode 100644 index 0000000..b4bb84e --- /dev/null +++ b/util/threading.cc @@ -0,0 +1,400 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "util/threading.h" + +#include + +#include // std::sort +#include +#include // std::make_unique +#include // std::move +#include + +// Placeholder for container detection, do not remove +#include "util/basics.h" +#include "hwy/base.h" +#include "hwy/contrib/thread_pool/thread_pool.h" +#include "hwy/contrib/thread_pool/topology.h" + +namespace gcpp { + +// Sort T := packages/clusters by descending 'size' so that users who only use +// one Group get the largest. +template +static void SortByDescendingSize(std::vector& groups) { + std::sort(groups.begin(), groups.end(), + [](const T& a, const T& b) { return a.Size() > b.Size(); }); +} + +BoundedTopology::BoundedTopology(BoundedSlice package_slice, + BoundedSlice cluster_slice, + BoundedSlice lp_slice) { + // Regardless of topology, ignore LPs disabled via OS, taskset, or numactl. + LPS enabled_lps; + if (HWY_UNLIKELY(!GetThreadAffinity(enabled_lps))) { + const size_t num_lps = hwy::TotalLogicalProcessors(); + fprintf(stderr, + "Warning, unknown OS affinity, considering all %zu LPs enabled\n.", + num_lps); + for (size_t lp = 0; lp < num_lps; ++lp) { + enabled_lps.Set(lp); + } + } + + // Without threading support, only keep the first enabled LP; it might still + // make sense to pin the main thread to avoid migrations. 
+ if (HWY_UNLIKELY(!hwy::HaveThreadingSupport())) { + HWY_ASSERT(enabled_lps.Any()); + const size_t lp = enabled_lps.First(); + enabled_lps = LPS(); + enabled_lps.Set(lp); + fprintf(stderr, + "Warning, threads not supported, using only the main thread\n."); + } + +#if !GEMMA_DISABLE_TOPOLOGY + if (HWY_LIKELY(!topology_.packages.empty())) { + InitFromTopology(enabled_lps, package_slice, cluster_slice); + } +#endif + + // Topology unknown or no packages with enabled LPs: create a single + // package with one cluster, and one node. + if (HWY_UNLIKELY(NumPackages() == 0)) { + InitFromSlice(enabled_lps, lp_slice); + } + + HWY_ASSERT(NumPackages() != 0 && NumClusters(0) != 0 && NumNodes() != 0); +} + +// Topology is unknown, rely on OS affinity and user-specified slice. +BoundedTopology::Cluster::Cluster(const LPS& enabled_lps, + BoundedSlice lp_slice) { + // Interpret `lp_slice` as a slice of the 1-bits of `enabled_lps`, so + // we honor both the OS affinity and the user-specified slice. Note that + // this can be used to exclude hyperthreads because Linux groups LPs by + // sibling index. For example, the first `num_cores` are not siblings. + const size_t detected = enabled_lps.Count(); + size_t enabled_idx = 0; + enabled_lps.Foreach([&](size_t lp) { + if (lp_slice.Contains(detected, enabled_idx++)) { + AddLP(lp); + } + }); + + // lp_slice can only reduce the number of `enabled_lps`, and not below 1. + HWY_ASSERT(num_workers_ != 0); +} + +BoundedTopology::Cluster::Cluster(const LPS& enabled_lps, + const std::vector& all_lps, + const hwy::Topology::Cluster& tcluster) { + bool is_first_lp = true; + + tcluster.lps.Foreach([&](size_t lp) { + // Skip if not first-hyperthread or disabled. + if (all_lps[lp].smt != 0 || !enabled_lps.Get(lp)) return; + + AddLP(lp); + + // Set `node` once, and ensure subsequent nodes match - we assume there + // is only one NUMA node per cluster. + const size_t lp_node = static_cast(all_lps[lp].node); + if (is_first_lp) { + is_first_lp = false; + node_ = lp_node; + } else { + static bool warned = false; + if (lp_node != node_ && !warned) { + warned = true; + fprintf(stderr, "WARNING: lp %zu on node %zu != cluster node %zu.\n", + lp, lp_node, node_); + } + } + }); +} + +// NOTE: caller is responsible for checking whether `clusters` is empty. +BoundedTopology::Package::Package(const LPS& enabled_lps, + const hwy::Topology& topology, + size_t package_idx, + BoundedSlice cluster_slice) { + const hwy::Topology::Package& tpackage = topology.packages[package_idx]; + // Populate `clusters` with the subset of clusters in `cluster_slice` that + // have any enabled LPs. If `clusters` remains empty, the caller will + // skip this `Package`. + clusters.reserve(cluster_slice.Num(tpackage.clusters.size())); + cluster_slice.Foreach( + "cluster", tpackage.clusters.size(), [&](size_t cluster_idx) { + const hwy::Topology::Cluster& tcluster = tpackage.clusters[cluster_idx]; + Cluster cluster(enabled_lps, topology.lps, tcluster); + // Skip if empty, i.e. too few `enabled_lps`. + if (HWY_LIKELY(cluster.Size() != 0)) { + clusters.push_back(std::move(cluster)); + } + }); + SortByDescendingSize(clusters); +} + +#if !GEMMA_DISABLE_TOPOLOGY + +static size_t CoresFromLPs(const LPS& lps, const hwy::Topology& topology) { + LPS cores; + lps.Foreach([&](size_t lp) { + if (topology.lps[lp].smt == 0) cores.Set(lp); + }); + return cores.Count(); +} + +// Scans hwy::Topology for clusters and their size, for use by topology_string_. 
+static void ScanTClusters(hwy::Topology& topology_, size_t& max_tclusters, + size_t& max_tcluster_cores, + size_t& max_tcluster_lps) { + max_tclusters = 0; + max_tcluster_cores = 0; + max_tcluster_lps = 0; + for (size_t package_idx = 0; package_idx < topology_.packages.size(); + ++package_idx) { + const std::vector& tclusters = + topology_.packages[package_idx].clusters; + max_tclusters = HWY_MAX(max_tclusters, tclusters.size()); + size_t tcluster_cores = 0; + size_t tcluster_lps = 0; + for (size_t cluster_idx = 0; cluster_idx < tclusters.size(); + ++cluster_idx) { + const size_t cores = CoresFromLPs(tclusters[cluster_idx].lps, topology_); + const size_t lps = tclusters[cluster_idx].lps.Count(); + tcluster_cores = HWY_MAX(tcluster_cores, cores); + tcluster_lps = HWY_MAX(tcluster_lps, lps); + } + + if (tclusters.size() > 1 && tcluster_cores > 8) { + fprintf(stderr, + "Package %zu: multiple clusters with max size %zu, whereas CCX " + "only have 8, may indicate a bug in hwy::Topology.\n", + package_idx, tcluster_cores); + } + max_tcluster_cores = HWY_MAX(max_tcluster_cores, tcluster_cores); + max_tcluster_lps = HWY_MAX(max_tcluster_lps, tcluster_lps); + } + HWY_ASSERT(max_tclusters != 0); + HWY_ASSERT(max_tcluster_cores != 0); + HWY_ASSERT(max_tcluster_lps >= max_tcluster_cores); +} + +// Main part of ctor, called when topology is known. +void BoundedTopology::InitFromTopology(const LPS& enabled_lps, + BoundedSlice package_slice, + BoundedSlice cluster_slice) { + size_t max_tclusters, max_tcluster_cores, max_tcluster_lps; + ScanTClusters(topology_, max_tclusters, max_tcluster_cores, max_tcluster_lps); + + // (Possibly empty) subset of `Topology` packages that have `enabled_lps`. + package_slice.Foreach( + "package", topology_.packages.size(), [&](size_t package_idx) { + Package package(enabled_lps, topology_, package_idx, cluster_slice); + // Skip if empty, i.e. too few `enabled_lps`. + if (HWY_LIKELY(!package.clusters.empty())) { + packages_.push_back(std::move(package)); + } + }); + if (NumPackages() == 0) return; + SortByDescendingSize(packages_); + + // Remember NUMA nodes that we are actually using (not just enabled). + for (const Package& p : packages_) { + for (const Cluster& c : p.clusters) { + nodes_.Set(c.Node()); + } + } + + // Scan for max BoundedTopology clusters and their size, for topology_string_. 
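  // Example of the resulting string: "2S 4X 8C 2H, using 1S 4X 8C (nodes=1)"
  // reads as: 2 packages (S), up to 4 clusters (X) of up to 8 cores (C) with
  // 2 hyperthreads per core (H) detected; this run uses 1 package with 4
  // clusters of up to 8 cores on 1 NUMA node. (Values here are hypothetical.)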
+ size_t all_max_cluster_size = 0; + for (size_t package_idx = 0; package_idx < NumPackages(); ++package_idx) { + size_t max_cluster_size = 0; + for (size_t cluster_idx = 0; cluster_idx < NumClusters(package_idx); + ++cluster_idx) { + max_cluster_size = HWY_MAX(max_cluster_size, + GetCluster(package_idx, cluster_idx).Size()); + } + if (NumClusters(package_idx) > 1 && max_cluster_size > 8) { + fprintf(stderr, + "Package %zu: multiple clusters with max size %zu, whereas CCX " + "only have 8, may indicate a bug in BoundedTopology.\n", + package_idx, max_cluster_size); + } + all_max_cluster_size = HWY_MAX(all_max_cluster_size, max_cluster_size); + } + + snprintf(topology_string_, sizeof(topology_string_), + "%zuS %zuX %zuC %zuH, using %zuS %zuX %zuC (nodes=%zu)", + topology_.packages.size(), max_tclusters, max_tcluster_cores, + max_tcluster_lps / max_tcluster_cores, packages_.size(), + NumClusters(0), all_max_cluster_size, nodes_.Count()); +} + +#endif // !GEMMA_DISABLE_TOPOLOGY + +void BoundedTopology::InitFromSlice(const LPS& enabled_lps, + BoundedSlice lp_slice) { + packages_.push_back(Package(enabled_lps, lp_slice)); + + snprintf(topology_string_, sizeof(topology_string_), "LPs=%zu", + GetCluster(0, 0).Size()); + + // Assume a single NUMA node. + nodes_.Set(0); + HWY_ASSERT(NumNodes() == 1); +} + +static PoolPtr MakePool(size_t num_workers) { + // `ThreadPool` expects the number of threads to create, which is one less + // than the number of workers, but avoid underflow if zero. + const size_t num_threads = num_workers == 0 ? 0 : num_workers - 1; + return std::make_unique(num_threads); +} + +static bool InContainer() { + return false;} + +class NestedPools::Pinning { + public: + Pinning(Tristate pin, const BoundedTopology& topology) { + if (pin == Tristate::kDefault) { + // Pinning is unreliable inside containers because the hypervisor might + // periodically change our affinity mask, or other processes might also + // pin themselves to the same LPs. + pin = InContainer() ? Tristate::kFalse : Tristate::kTrue; + } + want_pin_ = (pin == Tristate::kTrue); + } + + // If want_pin_, tries to pin each worker in `pool` to an LP in `cluster`, + // and sets `any_error_` if any fails. + void MaybePin(const BoundedTopology::Cluster& cluster, PoolPtr& pool) { + if (HWY_UNLIKELY(!want_pin_)) return; + + const std::vector lps = cluster.LPVector(); + HWY_ASSERT(pool->NumWorkers() <= lps.size()); + pool->Run( + 0, pool->NumWorkers(), + [this, &pool, &lps](uint64_t task, size_t thread) { + HWY_ASSERT(task == thread); // each worker has one task + if (HWY_UNLIKELY(!hwy::PinThreadToLogicalProcessor(lps[task]))) { + fprintf(stderr, + "Pinning failed for task %zu of %zu to %zu (size %zu)\n", + task, pool->NumWorkers(), lps[task], lps.size()); + (void)any_error_.test_and_set(); + } + }); + } + + bool WantPin() const { return want_pin_; } + + // Called ONCE after all MaybePin because it invalidates the error status. + bool AllPinned() { + // If !want_pin_, MaybePin will return without setting any_error_, but in + // that case we still want to return false to avoid spinning. + // .test() was only added in C++20, so we use .test_and_set() instead. + return want_pin_ && !any_error_.test_and_set(); + } + + private: + std::atomic_flag any_error_ = ATOMIC_FLAG_INIT; + bool want_pin_; // set in ctor +}; // Pinning + +// Used to divide max_threads and max_workers_per_package across packages and +// clusters. Ensures small upper bounds are respected. 
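// Examples: max = 0 -> 0 (unlimited); max = 8 across 2 instances -> 4 each;
// max = 3 across 8 instances -> 3 each rather than 0, so the bound still holds.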
+static size_t DivideMaxAcross(const size_t max, const size_t instances) { + // No limit. + if (max == 0) return 0; + // We have enough to distribute. + if (max >= instances) return max / instances; + // Use max as the upper bound for each instance because division would return + // zero, which means 'unlimited'. + return max; +} + +NestedPools::NestedPools(size_t max_threads, Tristate pin, + BoundedSlice package_slice, BoundedSlice cluster_slice, + BoundedSlice lp_slice) + : topology_(package_slice, cluster_slice, lp_slice) { + Pinning pinning(pin, topology_); + packages_.resize(topology_.NumPackages()); + all_packages_ = MakePool(packages_.size()); + const size_t max_workers_per_package = + DivideMaxAcross(max_threads, packages_.size()); + // Each worker in all_packages_, including the main thread, will be the + // calling thread of an all_clusters->Run, and hence pinned to one of the + // `cluster.lps` if `pin`. + all_packages_->Run( + 0, all_packages_->NumWorkers(), [&](uint64_t package_idx, size_t thread) { + HWY_ASSERT(package_idx == thread); // each thread has one task + packages_[package_idx] = Package( + topology_, package_idx, max_workers_per_package, pinning, lp_slice); + }); + + all_pinned_ = pinning.AllPinned(); + pin_string_ = all_pinned_ ? "pinned" + : pinning.WantPin() ? "pinning failed" + : "pinning skipped"; + + // For mapping package/cluster/thread to noncontiguous TLS indices, in case + // cluster/thread counts differ. + HWY_ASSERT(!packages_.empty() && packages_.size() <= 16); + for (const Package& p : packages_) { + max_clusters_per_package_ = + HWY_MAX(max_clusters_per_package_, p.NumClusters()); + max_workers_per_cluster_ = + HWY_MAX(max_workers_per_cluster_, p.MaxWorkersPerCluster()); + } + HWY_ASSERT(max_clusters_per_package_ >= 1); + HWY_ASSERT(max_clusters_per_package_ <= 64); + HWY_ASSERT(max_workers_per_cluster_ >= 1); + HWY_ASSERT(max_workers_per_cluster_ <= 256); +} + +// `max_or_zero` == 0 means no limit. +static inline size_t CapIfNonZero(size_t num, size_t max_or_zero) { + return (max_or_zero == 0) ? num : HWY_MIN(num, max_or_zero); +} + +NestedPools::Package::Package(const BoundedTopology& topology, + size_t package_idx, + size_t max_workers_per_package, Pinning& pinning, + BoundedSlice lp_slice) { + // Pre-allocate because elements are set concurrently. + clusters_.resize(topology.NumClusters(package_idx)); + const size_t max_workers_per_cluster = + DivideMaxAcross(max_workers_per_package, clusters_.size()); + + all_clusters_ = MakePool(clusters_.size()); + // Parallel so we also pin the calling worker in `all_clusters` to + // `cluster.lps`. + all_clusters_->Run( + 0, all_clusters_->NumWorkers(), [&](size_t cluster_idx, size_t thread) { + HWY_ASSERT(cluster_idx == thread); // each thread has one task + const BoundedTopology::Cluster& cluster = + topology.GetCluster(package_idx, cluster_idx); + clusters_[cluster_idx] = + MakePool(CapIfNonZero(cluster.Size(), max_workers_per_cluster)); + // Pin workers AND the calling thread from `all_clusters`. 
+        pinning.MaybePin(cluster, clusters_[cluster_idx]);
+      });
+}
+
+}  // namespace gcpp
diff --git a/util/threading.h b/util/threading.h
index 6dbf806..6be1503 100644
--- a/util/threading.h
+++ b/util/threading.h
@@ -17,14 +17,12 @@
 #define THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_
 
 #include <stddef.h>
-#include <stdio.h>
 
-#include <algorithm>  // std::sort
-#include <memory>     // std::unique_ptr
-#include <utility>    // std::move
+#include <memory>  // std::unique_ptr
 #include <vector>
 
-#include "hwy/base.h"  // HWY_ASSERT
+#include "util/basics.h"  // Tristate
+#include "hwy/base.h"     // HWY_ASSERT
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/contrib/thread_pool/topology.h"
 
@@ -78,6 +76,10 @@ class BoundedSlice {
 // "LP" is a logical processor, a 0-based index passed to the OS.
 using LPS = hwy::LogicalProcessorSet;
 
+// We want vectors of hwy::ThreadPool, which is unfortunately not movable,
+// hence we wrap them in unique_ptr.
+using PoolPtr = std::unique_ptr<hwy::ThreadPool>;
+
 // Wraps hwy::Topology and only keeps the subset of packages and clusters
 // apportioned by BoundedSlice, further limited by the OS affinity mask.
 // NOTE: if topology is unknown or the OS affinity is too restrictive, we fall
@@ -85,96 +87,18 @@ using LPS = hwy::LogicalProcessorSet;
 class BoundedTopology {
  public:
   BoundedTopology(BoundedSlice package_slice, BoundedSlice cluster_slice,
-                  BoundedSlice lp_slice) {
-    // Regardless of topology, ignore LPs disabled via OS, taskset, or numactl.
-    LPS enabled_lps;
-    if (HWY_UNLIKELY(!GetThreadAffinity(enabled_lps))) {
-      const size_t num_lps = hwy::TotalLogicalProcessors();
-      fprintf(
-          stderr,
-          "Warning, unknown OS affinity, considering all %zu LPs enabled\n.",
-          num_lps);
-      for (size_t lp = 0; lp < num_lps; ++lp) {
-        enabled_lps.Set(lp);
-      }
-    }
-
-    // Without threading support, only keep the first enabled LP; it might still
-    // make sense to pin the main thread.
-    if (HWY_UNLIKELY(!hwy::HaveThreadingSupport())) {
-      HWY_ASSERT(enabled_lps.Any());
-      const size_t lp = enabled_lps.First();
-      enabled_lps = LPS();
-      enabled_lps.Set(lp);
-    }
-
-#if !GEMMA_DISABLE_TOPOLOGY
-    if (HWY_LIKELY(!topology_.packages.empty())) {
-      InitFromTopology(enabled_lps, package_slice, cluster_slice);
-    }
-#endif
-
-    // Topology unknown, disabled or no packages with enabled LPs: create a
-    // single package with one cluster, and one node.
-    if (HWY_UNLIKELY(NumPackages() == 0)) {
-      InitFromSlice(enabled_lps, lp_slice);
-    }
-
-    HWY_ASSERT(NumPackages() != 0 && NumClusters(0) != 0 && NumNodes() != 0);
-  }
+                  BoundedSlice lp_slice);
 
   size_t NumPackages() const { return packages_.size(); }
-  const char* TopologyString() const { return topology_string_; }
   size_t NumNodes() const { return nodes_.Count(); }
+  const char* TopologyString() const { return topology_string_; }
 
   class Cluster {
    public:
-    // Topology is unknown, rely on OS affinity and user-specified slice.
-    Cluster(const LPS& enabled_lps, BoundedSlice lp_slice) {
-      // Interpret `lp_slice` as a slice of the 1-bits of `enabled_lps`, so
-      // we honor both the OS affinity and the user-specified slice. Note that
-      // this can be used to exclude hyperthreads because Linux groups LPs by
-      // sibling index. For example, the first `num_cores` are not siblings.
-      const size_t detected = enabled_lps.Count();
-      size_t enabled_idx = 0;
-      enabled_lps.Foreach([&](size_t lp) {
-        if (lp_slice.Contains(detected, enabled_idx++)) {
-          AddLP(lp);
-        }
-      });
-
-      // lp_slice can only reduce the number of `enabled_lps`, and not below 1.
-      HWY_ASSERT(num_workers_ != 0);
-    }
-
+    Cluster(const LPS& enabled_lps, BoundedSlice lp_slice);
     Cluster(const LPS& enabled_lps, const std::vector<hwy::Topology::LP>& all_lps,
-            const hwy::Topology::Cluster& tcluster) {
-      bool is_first_lp = true;
-
-      tcluster.lps.Foreach([&](size_t lp) {
-        // Skip if not first-hyperthread or disabled.
-        if (all_lps[lp].smt != 0 || !enabled_lps.Get(lp)) return;
-
-        AddLP(lp);
-
-        // Set `node` once, and ensure subsequent nodes match - we assume there
-        // is only one NUMA node per cluster.
-        const size_t lp_node = static_cast<size_t>(all_lps[lp].node);
-        if (is_first_lp) {
-          is_first_lp = false;
-          node_ = lp_node;
-        } else {
-          static bool warned = false;
-          if (lp_node != node_ && !warned) {
-            warned = true;
-            fprintf(stderr,
-                    "WARNING: lp %zu on node %zu != cluster node %zu.\n", lp,
-                    lp_node, node_);
-          }
-        }
-      });
-    }
+            const hwy::Topology::Cluster& tcluster);
 
     // For SortByDescendingSize.
     size_t Size() const { return num_workers_; }
@@ -221,53 +145,15 @@ class BoundedTopology {
     return package.clusters[cluster_idx];
   }
 
-  // Returns total number of cluster workers, for deciding whether to pin.
-  size_t TotalWorkers() const {
-    size_t total_workers = 0;
-    for (size_t package_idx = 0; package_idx < NumPackages(); ++package_idx) {
-      const size_t num_clusters = NumClusters(package_idx);
-      for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) {
-        total_workers += GetCluster(package_idx, cluster_idx).Size();
-      }
-    }
-    return total_workers;
-  }
-
  private:
-  // Sort T := packages/clusters by descending 'size' so that users who only use
-  // one Group get the largest.
-  template <class T>
-  static void SortByDescendingSize(std::vector<T>& groups) {
-    std::sort(groups.begin(), groups.end(),
-              [](const T& a, const T& b) { return a.Size() > b.Size(); });
-  }
-
   struct Package {
     // Topology is unknown, rely on OS affinity and user-specified slice.
     Package(const LPS& enabled_lps, BoundedSlice lp_slice) {
      clusters.push_back(Cluster(enabled_lps, lp_slice));
    }
-    // NOTE: caller is responsible for checking whether `clusters` is empty.
     Package(const LPS& enabled_lps, const hwy::Topology& topology,
-            size_t package_idx, BoundedSlice cluster_slice) {
-      const hwy::Topology::Package& tpackage = topology.packages[package_idx];
-      // Populate `clusters` with the subset of clusters in `cluster_slice` that
-      // have any enabled LPs. If `clusters` remains empty, the caller will
-      // skip this `Package`.
-      clusters.reserve(cluster_slice.Num(tpackage.clusters.size()));
-      cluster_slice.Foreach(
-          "cluster", tpackage.clusters.size(), [&](size_t cluster_idx) {
-            const hwy::Topology::Cluster& tcluster =
-                tpackage.clusters[cluster_idx];
-            Cluster cluster(enabled_lps, topology.lps, tcluster);
-            // Skip if empty, i.e. too few `enabled_lps`.
-            if (HWY_LIKELY(cluster.Size() != 0)) {
-              clusters.push_back(std::move(cluster));
-            }
-          });
-      SortByDescendingSize(clusters);
-    }
+            size_t package_idx, BoundedSlice cluster_slice);
 
     // For SortByDescendingSize.
     size_t Size() const { return clusters.size(); }
@@ -275,48 +161,9 @@ class BoundedTopology {
     std::vector<Cluster> clusters;
   };  // Package
 
-#if !GEMMA_DISABLE_TOPOLOGY
-  // Main part of ctor, called when topology is known.
   void InitFromTopology(const LPS& enabled_lps, BoundedSlice package_slice,
-                        BoundedSlice cluster_slice) {
-    // (Possibly empty) subset of `Topology` packages that have `enabled_lps`.
- package_slice.Foreach( - "package", topology_.packages.size(), [&](size_t package_idx) { - Package package(enabled_lps, topology_, package_idx, cluster_slice); - // Skip if empty, i.e. too few `enabled_lps`. - if (HWY_LIKELY(!package.clusters.empty())) { - packages_.push_back(std::move(package)); - } - }); - if (NumPackages() == 0) return; - SortByDescendingSize(packages_); - - const hwy::Topology::Package& tpackage0 = topology_.packages[0]; - HWY_ASSERT(!tpackage0.clusters.empty()); - const hwy::Topology::Cluster& tcluster0 = tpackage0.clusters[0]; - // GetCluster(0, 0) is valid because only non-empty Packages were kept. - snprintf(topology_string_, sizeof(topology_string_), - "%zux%zux%zu, using %zux%zux%zu", topology_.packages.size(), - tpackage0.clusters.size(), tcluster0.lps.Count(), packages_.size(), - NumClusters(0), GetCluster(0, 0).Size()); - - // Remember NUMA nodes of *enabled* LPs. - enabled_lps.Foreach([&](size_t lp) { - nodes_.Set(static_cast(topology_.lps[lp].node)); - }); - } -#endif - - void InitFromSlice(const LPS& enabled_lps, BoundedSlice lp_slice) { - packages_.push_back(Package(enabled_lps, lp_slice)); - - snprintf(topology_string_, sizeof(topology_string_), "LPs=%zu", - GetCluster(0, 0).Size()); - - // Assume a single NUMA node. - nodes_.Set(0); - HWY_ASSERT(NumNodes() == 1); - } + BoundedSlice cluster_slice); + void InitFromSlice(const LPS& enabled_lps, BoundedSlice lp_slice); #if !GEMMA_DISABLE_TOPOLOGY hwy::Topology topology_; @@ -360,51 +207,32 @@ class NestedPools { // would cause huge slowdowns when spinning, the `BoundedSlice` arguments // only impose upper bounds on the number of detected packages and clusters // rather than defining the actual number of threads. - // - // `pin` is 0 or 1 to force disable/enable, or -1 to choose automatically. - NestedPools(size_t max_threads, int pin = -1, + NestedPools(size_t max_threads, Tristate pin = Tristate::kDefault, BoundedSlice package_slice = BoundedSlice(), BoundedSlice cluster_slice = BoundedSlice(), - BoundedSlice lp_slice = BoundedSlice()) - : topology_(package_slice, cluster_slice, lp_slice) { - if (pin == -1) pin = topology_.TotalWorkers() >= 12; + BoundedSlice lp_slice = BoundedSlice()); - packages_.resize(topology_.NumPackages()); - all_packages_ = MakePool(packages_.size()); - const size_t max_workers_per_package = max_threads / packages_.size(); - // Each worker in all_packages_, including the main thread, will be the - // calling thread of an all_clusters->Run, and hence pinned to one of the - // `cluster.lps` if `pin`. - all_packages_->Run( - 0, all_packages_->NumWorkers(), - [&](uint64_t package_idx, size_t thread) { - HWY_ASSERT(package_idx == thread); // each thread has one task - packages_[package_idx] = Package( - topology_, package_idx, max_workers_per_package, pin, lp_slice); - }); - - // For mapping package/cluster/thread to noncontiguous TLS indices, in case - // cluster/thread counts differ. - HWY_ASSERT(!packages_.empty() && packages_.size() <= 16); - for (const Package& p : packages_) { - max_clusters_per_package_ = - HWY_MAX(max_clusters_per_package_, p.NumClusters()); - max_workers_per_cluster_ = - HWY_MAX(max_workers_per_cluster_, p.MaxWorkersPerCluster()); + // Subject to `use_spinning`, enables spin waits with the goal of reducing the + // latency of barrier synchronization. We only spin during Generate to avoid + // wasting energy during long waits. If `use_spinning` is kDefault, we first + // set it to kTrue or kFalse based on a heuristic. 
+ void MaybeStartSpinning(Tristate& use_spinning) { + if (HWY_UNLIKELY(use_spinning == Tristate::kDefault)) { + // The default is to only spin when pinning was enabled and supported by + // the OS. Unless spin-waits have near-exclusive use of a core, the tail + // latency can be higher than blocking waits. + use_spinning = all_pinned_ ? Tristate::kTrue : Tristate::kFalse; + } + if (use_spinning == Tristate::kTrue) { + SetWaitMode(hwy::PoolWaitMode::kSpin); + } + } + void MaybeStopSpinning(const Tristate use_spinning) { + HWY_DASSERT(use_spinning != Tristate::kDefault); // see MaybeStartSpinning + if (use_spinning == Tristate::kTrue) { + SetWaitMode(hwy::PoolWaitMode::kBlock); } - HWY_ASSERT(max_clusters_per_package_ >= 1); - HWY_ASSERT(max_clusters_per_package_ <= 64); - HWY_ASSERT(max_workers_per_cluster_ >= 1); - HWY_ASSERT(max_workers_per_cluster_ <= 256); } - - // Spinning reduces the latency of barrier synchronization, but wastes lots - // of energy for long waits, so only do it during generation. Spinning might - // also be unsafe in virtualized environments because we require threads to - // be running on their own core and thus responsive to the barrier - // synchronization. - void StartSpinning() { SetWaitMode(hwy::PoolWaitMode::kSpin); } - void StopSpinning() { SetWaitMode(hwy::PoolWaitMode::kBlock); } hwy::ThreadPool& AllPackages() { return *all_packages_; } hwy::ThreadPool& AllClusters(size_t package_idx) { @@ -435,7 +263,9 @@ class NestedPools { // For Allocator const BoundedTopology& Topology() const { return topology_; } + // For ShowConfig const char* TopologyString() const { return topology_.TopologyString(); } + const char* PinString() const { return pin_string_; } // Returns a single pool on the first package: either one thread per cluster // if there is more than one, which maximizes available memory bandwidth, or @@ -449,56 +279,14 @@ class NestedPools { } private: - // `max_or_zero` == 0 means no limit. - static inline size_t CapIfNonZero(size_t num, size_t max_or_zero) { - return (max_or_zero == 0) ? num : HWY_MIN(num, max_or_zero); - } - - // We want vectors of hwy::ThreadPool, which is unfortunately not movable, - // hence we wrap them in unique_ptr. - using PoolPtr = std::unique_ptr; - - static PoolPtr MakePool(size_t num_workers) { - // `ThreadPool` expects the number of threads to create, which is one less - // than the number of workers, but avoid underflow if zero. - const size_t num_threads = num_workers == 0 ? 0 : num_workers - 1; - return std::make_unique(num_threads); - } + class Pinning; class Package { public: Package() = default; // for vector Package(const BoundedTopology& topology, size_t package_idx, - size_t max_workers_per_package, int pin, BoundedSlice lp_slice) { - // Pre-allocate because elements are set concurrently. - clusters_.resize(topology.NumClusters(package_idx)); - const size_t max_workers_per_cluster = - max_workers_per_package / clusters_.size(); - - all_clusters_ = MakePool(clusters_.size()); - // Parallel so we also pin the calling worker in `all_clusters` to - // `cluster.lps`. - all_clusters_->Run( - 0, all_clusters_->NumWorkers(), - [&](size_t cluster_idx, size_t thread) { - HWY_ASSERT(cluster_idx == thread); // each thread has one task - const BoundedTopology::Cluster& cluster = - topology.GetCluster(package_idx, cluster_idx); - clusters_[cluster_idx] = - MakePool(CapIfNonZero(cluster.Size(), max_workers_per_cluster)); - if (HWY_LIKELY(pin)) { - // Pin threads AND the calling thread from `all_clusters` to lps. 
-            const std::vector<size_t> lps = cluster.LPVector();
-            HWY_ASSERT(clusters_[cluster_idx]->NumWorkers() <= lps.size());
-            clusters_[cluster_idx]->Run(
-                0, clusters_[cluster_idx]->NumWorkers(),
-                [&lps](uint64_t task, size_t thread) {
-                  HWY_ASSERT(task == thread);  // each worker has one task
-                  hwy::PinThreadToLogicalProcessor(lps[task]);
-                });
-          }
-        });
-    }
+            size_t max_workers_per_package, Pinning& pinning,
+            BoundedSlice lp_slice);
 
     size_t NumClusters() const { return clusters_.size(); }
     size_t MaxWorkersPerCluster() const {
@@ -536,6 +324,8 @@ class NestedPools {
   }
 
   BoundedTopology topology_;
+  bool all_pinned_;
+  const char* pin_string_;
 
   std::vector<Package> packages_;
   PoolPtr all_packages_;
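For context, a minimal usage sketch of the reworked interface (illustrative only, not part of the patch; the `main` wrapper and the generation call site are mine):

  #include <stdio.h>

  #include "util/basics.h"     // Tristate
  #include "util/threading.h"  // NestedPools

  int main() {
    // max_threads == 0 means no limit; Tristate::kDefault lets the library
    // decide pinning via the InContainer() heuristic.
    gcpp::NestedPools pools(/*max_threads=*/0, /*pin=*/gcpp::Tristate::kDefault);
    fprintf(stderr, "%s, %s\n", pools.TopologyString(), pools.PinString());

    // Spin only around the latency-sensitive region. kDefault resolves to
    // kTrue iff pinning succeeded on all workers.
    gcpp::Tristate use_spinning = gcpp::Tristate::kDefault;
    pools.MaybeStartSpinning(use_spinning);
    // ... e.g. run generation using pools.Pool() ...
    pools.MaybeStopSpinning(use_spinning);
    return 0;
  }

Note that the same `use_spinning` variable must be passed to both calls: MaybeStartSpinning overwrites kDefault with the resolved value, and MaybeStopSpinning asserts it is no longer kDefault.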