diff --git a/gemma/activations.h b/gemma/activations.h index 9db3dee..45096cc 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -18,6 +18,7 @@ #include // sqrtf #include +#include #include #include @@ -99,6 +100,7 @@ struct AttentionActivations { 1000000.0)), div_seq_len(static_cast(seq_len)), + div_heads(static_cast(layer_config.heads)), query_scale(ChooseQueryScale(config)) { // Batch size can be 0 in experimental code so do not assert. if (batch_size == 0) { @@ -125,10 +127,6 @@ struct AttentionActivations { att_sums.OverrideRows(batch_size); } - bool IsGlobalLayer(size_t layer_idx) const { - return config.attention_window_sizes[layer_idx] == div_seq_len.GetDivisor(); - } - const ModelConfig& config; MatStorageT q; // query @@ -144,6 +142,8 @@ struct AttentionActivations { MatStorageT inv_timescale_global; hwy::Divisor div_seq_len; + // Unfortunately, some models (Griffin) have non-power-of-two heads. + hwy::Divisor div_heads; float query_scale; }; diff --git a/gemma/attention.cc b/gemma/attention.cc index 0b9471e..39c75e8 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -52,7 +52,9 @@ namespace HWY_NAMESPACE { static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos, const hwy::Divisor& div_seq_len, const float* HWY_RESTRICT q, - const MatPtrT& k, float* HWY_RESTRICT att) { + const MatPtrT& k, float* HWY_RESTRICT att, + const size_t worker) { + PROFILER_ZONE2(worker, "Gen.Attention.QDotK"); if (HWY_LIKELY(last_pos < static_cast(div_seq_len.GetDivisor()))) { // Slightly faster: no wraparound. for (size_t pos = start_pos; pos <= last_pos; ++pos) { @@ -71,7 +73,8 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos, static void PositionalEncodingQK(float* qk, const size_t layer_idx, const LayerWeightsPtrs& layer, const AttentionActivations& activations, - const size_t pos, const float mul = 1.0f) { + const size_t worker, const size_t pos, + const float mul = 1.0f) { const size_t qkv_dim = layer.layer_config.qkv_dim; const PostQKType& post_qk = layer.layer_config.post_qk; // qk is either q or k, so qkv_dim is the length we operate on. @@ -83,10 +86,10 @@ static void PositionalEncodingQK(float* qk, const size_t layer_idx, } // PostQKType::Rope if (post_qk == PostQKType::HalfRope) { - Rope(qk, qkv_dim / 2, inv_timescale, pos); - if (mul != 1.0f) MulByConst(mul, qk, qkv_dim); + Rope(qk, qkv_dim / 2, inv_timescale, pos, worker); + if (mul != 1.0f) MulByConst(mul, qk, qkv_dim, worker); } else { - RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos); + RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, worker); } } @@ -94,39 +97,38 @@ static void PositionalEncodingQK(float* qk, const size_t layer_idx, // `att_out`. Equivalent in gemma/modules.py: // encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj) // `v` is a strided view of the kv cache with dimensions [seq_len, qkv_dim]. -static HWY_INLINE void WeightedSumV(const size_t start_pos, - const size_t last_pos, - const hwy::Divisor& div_seq_len, - const float* HWY_RESTRICT att, - const MatPtrT& v, - float* HWY_RESTRICT att_out) { - const size_t qkv_dim = v.Cols(); - hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out)); - +static HWY_INLINE void WeightedSumV( + const size_t start_pos, const size_t last_pos, + const hwy::Divisor& div_seq_len, const float* HWY_RESTRICT att, + const MatPtrT& v, float* HWY_RESTRICT att_out, const size_t worker) { if (HWY_LIKELY(last_pos < static_cast(div_seq_len.GetDivisor()))) { - // Slightly faster: no wraparound. 
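// Illustration, not part of this change: a scalar reference for what the
// MulByConstTo + MulByConstAndAdd sequence below computes. It assumes the
// same BF16 KV element type as the views built in DotSoftmaxWeightedSum;
// positions wrap into the fixed-size KV ring buffer via Remainder().
//
// static void WeightedSumVScalar(size_t start_pos, size_t last_pos,
//                                const hwy::Divisor& div_seq_len,
//                                const float* HWY_RESTRICT att,
//                                const MatPtrT<BF16>& v,
//                                float* HWY_RESTRICT att_out) {
//   hwy::ZeroBytes(att_out, v.Cols() * sizeof(*att_out));
//   for (size_t pos = start_pos; pos <= last_pos; ++pos) {
//     const size_t row = div_seq_len.Remainder(pos);  // no-op if pos < seq_len
//     for (size_t d = 0; d < v.Cols(); ++d) {
//       att_out[d] += att[row] * hwy::ConvertScalarTo<float>(v.Row(row)[d]);
//     }
//   }
// }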
- for (size_t pos = start_pos; pos <= last_pos; ++pos) { - MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols()); + // Slightly faster: no wraparound. Could be replaced with MatMul(att, v) if + // we supported non-transposed B. + // TODO: 2..4x unroll + MulByConstTo(att[start_pos], v.Row(start_pos), att_out, v.Cols(), worker); + for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) { + MulByConstAndAdd(att[pos], v.Row(pos), att_out, v.Cols(), worker); } } else { - for (size_t pos = start_pos; pos <= last_pos; ++pos) { - const size_t pos_modulo = div_seq_len.Remainder(pos); - const BF16* HWY_RESTRICT v_ptr = v.Row(pos_modulo); - MulByConstAndAdd(att[pos_modulo], v_ptr, att_out, v.Cols()); + { + const size_t pos_mod = div_seq_len.Remainder(start_pos); + MulByConstTo(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), worker); + } + for (size_t pos = start_pos + 1; pos <= last_pos; ++pos) { + const size_t pos_mod = div_seq_len.Remainder(pos); + MulByConstAndAdd(att[pos_mod], v.Row(pos_mod), att_out, v.Cols(), worker); } } } // Calculates the attention outputs for a single q, which may be updated // in place for RMSNorm. -void SingleDotSoftmaxWeightedSum(const size_t pos, const size_t start_pos, - const size_t last_pos, float* HWY_RESTRICT q, - const MatPtrT& k, const MatPtrT& v, - const size_t layer_idx, - const LayerWeightsPtrs& layer, - const AttentionActivations& activations, - float* HWY_RESTRICT att, - float* HWY_RESTRICT att_out) { +void SingleDotSoftmaxWeightedSum( + const size_t pos, const size_t start_pos, const size_t last_pos, + float* HWY_RESTRICT q, const MatPtrT& k, const MatPtrT& v, + const size_t layer_idx, const LayerWeightsPtrs& layer, + const AttentionActivations& activations, float* HWY_RESTRICT att, + float* HWY_RESTRICT att_out, const size_t worker) { const float att_cap = activations.config.att_cap; const float query_scale = activations.query_scale; const size_t seq_len = @@ -136,20 +138,22 @@ void SingleDotSoftmaxWeightedSum(const size_t pos, const size_t start_pos, if (layer.query_norm_scale.HasPtr()) { CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) { RMSNormInplace(weights_t->PackedScale1(), 0, q, - layer.layer_config.qkv_dim); + layer.layer_config.qkv_dim, worker); }); } - PositionalEncodingQK(q, layer_idx, layer, activations, pos, query_scale); + PositionalEncodingQK(q, layer_idx, layer, activations, worker, pos, + query_scale); - QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att); + QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, worker); // SoftMax with optional SoftCap yields "probabilities" in att. const size_t att_len = HWY_MIN(last_pos + 1, seq_len); - MaybeLogitsSoftCap(att_cap, att, att_len); - Softmax(att, att_len); + MaybeLogitsSoftCap(att_cap, att, att_len, worker); + Softmax(att, att_len, /*temperature=*/1.0f, worker); - WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out); + WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out, + worker); } // The attention window usually starts at 0 unless `pos` is larger than @@ -179,75 +183,52 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx, const size_t cache_layer_size = layer_config.CacheLayerSize(); const size_t seq_len = static_cast(activations.div_seq_len.GetDivisor()); + // All layers should have the same number of heads. + HWY_DASSERT(activations.div_heads.GetDivisor() == layer_config.heads); // For each head/token/query, compute Q.K, softmax, and weighted V. 
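// Illustration, not part of this change: the rewrite below fuses
// (token/query, head) into one task index so a single ParallelFor can
// load-balance across all clusters:
//   task   = tq_idx * layer_config.heads + head
//   tq_idx = activations.div_heads.Divide(task)     // task / heads
//   head   = activations.div_heads.Remainder(task)  // task % heads
// hwy::Divisor avoids a hardware divide per task, which matters because some
// models (Griffin) have a non-power-of-two head count. The total task count,
// num_tokens * div_qbatch.GetDivisor() * layer_config.heads, matches the
// ParallelFor call at the end of this function.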
- - // Statically partition token/query across packages. - const size_t num_tq = num_tokens * div_qbatch.GetDivisor(); - const IndexRangePartition tq_ranges = - StaticPartition(IndexRange(0, num_tq), pools.NumPackages(), 1); - ParallelizeOneRange( - tq_ranges, pools.AllPackages(), - [&](const IndexRange& tq_range, const size_t pkg_idx) { - const size_t pkg_base = pkg_idx * pools.MaxWorkersPerPackage(); - pools.AllClusters(pkg_idx).Run( - tq_range.begin(), tq_range.end(), - [&](const size_t tq_idx, const size_t cluster_idx) { - const HWY_MAYBE_UNUSED size_t cluster_base = - pkg_base + cluster_idx * pools.MaxWorkersPerCluster(); - const size_t qi = div_qbatch.Remainder(tq_idx); - const size_t batch_idx = div_qbatch.Divide(tq_idx); - auto& kv_cache = qbatch.KV(qi).kv_cache; - - // Find the token position in the query and calculate - // the range of cache positions to attend to. - const size_t pos = qbatch.Pos(qi) + batch_idx; - const size_t start_pos = - StartPos(pos, activations.config, layer_idx); - size_t last_pos = pos; - const size_t prefix_end = qbatch.PrefixEnd(qi); - if (prefix_end > 0 && prefix_end - 1 > last_pos) { - // last_pos in QDotK and WeightedSumV is inclusive. - last_pos = prefix_end - 1; - } - - pools.Cluster(pkg_idx, cluster_idx) - .Run( - 0, layer_config.heads, - [&](const size_t head, size_t thread) HWY_ATTR { + const auto func = [&](const size_t task, size_t worker) HWY_ATTR { + const size_t tq_idx = activations.div_heads.Divide(task); + const size_t head = activations.div_heads.Remainder(task); #if PROFILER_ENABLED - const hwy::Zone zone(cluster_base + thread, - zone_id_par); + const hwy::Zone zone(worker, zone_id_par); #endif - const size_t head_offset = - (head / kHeadGroups) * qkv_dim * 2; + const size_t qi = div_qbatch.Remainder(tq_idx); + const size_t batch_idx = div_qbatch.Divide(tq_idx); + auto& kv_cache = qbatch.KV(qi).kv_cache; - float* HWY_RESTRICT q = - activations.q.Row(tq_idx) + head * qkv_dim; + // Find the token position in the query and calculate + // the range of cache positions to attend to. + const size_t pos = qbatch.Pos(qi) + batch_idx; + const size_t start_pos = StartPos(pos, activations.config, layer_idx); + size_t last_pos = pos; + const size_t prefix_end = qbatch.PrefixEnd(qi); + if (prefix_end > 0 && prefix_end - 1 > last_pos) { + // last_pos in QDotK and WeightedSumV is inclusive. + last_pos = prefix_end - 1; + } - float* HWY_RESTRICT att = - activations.att.Row(tq_idx) + head * seq_len; - float* HWY_RESTRICT att_out = - activations.att_out.Row(tq_idx) + head * qkv_dim; + float* HWY_RESTRICT q = activations.q.Row(tq_idx) + head * qkv_dim; + float* HWY_RESTRICT att = activations.att.Row(tq_idx) + head * seq_len; + float* HWY_RESTRICT att_out = + activations.att_out.Row(tq_idx) + head * qkv_dim; - // Make strided read-only views into the kv cache for - // this query and head. - const size_t kv_head_offset = - layer_idx * cache_layer_size + head_offset; - MatPtrT k("k_view", Extents2D(seq_len, qkv_dim)); - k.SetPtr(kv_cache.Row(0) + kv_head_offset, - kv_cache.Stride()); - MatPtrT v("v_view", Extents2D(seq_len, qkv_dim)); - v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, - kv_cache.Stride()); + // Make strided read-only views into the kv cache for + // this query and head. 
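// Illustration, not part of this change: assumed layout behind the strided
// views created below. Each kv_cache row is one cache position; within a
// row, each layer stores kv_heads blocks of [K | V], each qkv_dim wide.
// With grouped-query attention, kHeadGroups query heads share one KV head:
//   kv_head  = head / kHeadGroups
//   k_offset = layer_idx * cache_layer_size + kv_head * 2 * qkv_dim
//   v_offset = k_offset + qkv_dim
// Both views reuse kv_cache.Stride(), so Row(pos) walks cache positions.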
+ const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2; + const size_t kv_head_offset = layer_idx * cache_layer_size + head_offset; + MatPtrT k("k_view", Extents2D(seq_len, qkv_dim)); + k.SetPtr(kv_cache.Row(0) + kv_head_offset, kv_cache.Stride()); + MatPtrT v("v_view", Extents2D(seq_len, qkv_dim)); + v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride()); - SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, - k, v, layer_idx, layer, - activations, att, att_out); - }); - }); - }); + SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v, layer_idx, + layer, activations, att, att_out, worker); + }; + + ParallelFor(num_tokens * div_qbatch.GetDivisor() * layer_config.heads, pools, + /*pkg_idx=*/0, func); } // Different functions use different naming conventions for the number of @@ -286,10 +267,10 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, const size_t batch_idx = div_qbatch.Divide(interleaved_idx); const size_t cache_pos = activations.div_seq_len.Remainder(qbatch.Pos(qi) + batch_idx); - env.row_ptrs[0][interleaved_idx] = reinterpret_cast( + env.row_ptrs[2][interleaved_idx] = reinterpret_cast( qbatch.KV(qi).kv_cache.Row(cache_pos) + layer_idx * cache_layer_size); } - kv_rows.AttachRowPtrs(env.row_ptrs[0].get()); + kv_rows.AttachRowPtrs(env.row_ptrs[2].get()); CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2, /*add=*/nullptr, env, kv_rows); @@ -298,7 +279,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, // tasks are very lightweight. env.ctx.pools.Pool(0).Run( 0, kv_heads * num_interleaved, - [&](uint64_t task, size_t /*thread*/) HWY_ATTR { + [&](uint64_t task, size_t thread) HWY_ATTR { const size_t head = task % kv_heads; const size_t interleaved_idx = task / kv_heads; const size_t qi = div_qbatch.Remainder(interleaved_idx); @@ -318,11 +299,13 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, // Apply further processing to K. 
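// The per-head K post-processing below runs in this order before the
// [K | V] pair is committed to the cache:
//   1. optional RMSNorm with key_norm_scale (in place on the f32 K),
//   2. positional encoding (RoPE or half-RoPE) at position `pos`,
//   3. Compress() of the full 2 * qkv_dim floats into the cache element type.
// The pool's `thread` index is forwarded so the per-op profiler zones are
// attributed to the thread that ran them.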
if (layer.key_norm_scale.HasPtr()) { CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) { - RMSNormInplace(weights_t->PackedScale1(), 0, kv_f32, qkv_dim); + RMSNormInplace(weights_t->PackedScale1(), 0, kv_f32, qkv_dim, + thread); }); } - PositionalEncodingQK(kv_f32, layer_idx, layer, activations, pos); + PositionalEncodingQK(kv_f32, layer_idx, layer, activations, thread, + pos); CompressPerThread tls; Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0); }); diff --git a/gemma/attention.h b/gemma/attention.h index d00e81d..5419b7f 100644 --- a/gemma/attention.h +++ b/gemma/attention.h @@ -33,7 +33,7 @@ namespace gcpp { float* HWY_RESTRICT q, const MatPtrT& k, const MatPtrT& v, \ size_t layer_idx, const LayerWeightsPtrs& layer, \ const AttentionActivations& activations, float* HWY_RESTRICT att, \ - float* HWY_RESTRICT att_out); \ + float* HWY_RESTRICT att_out, size_t worker); \ \ void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \ const LayerWeightsPtrs& layer, \ diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index e89a3ad..7aa3318 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -23,6 +23,7 @@ #include "gemma/weights.h" #include "ops/matmul.h" #include "util/mat.h" +#include "util/threading.h" #include "hwy/profiler.h" // Include guard (still compiled once per target) @@ -43,9 +44,10 @@ namespace gcpp { namespace HWY_NAMESPACE { template -HWY_NOINLINE void Activation(ActivationType activation, T* HWY_RESTRICT c1, - const T* HWY_RESTRICT c2, size_t count) { - PROFILER_ZONE("Gen.Activation"); +void Activation(ActivationType activation, T* HWY_RESTRICT c1, + const T* HWY_RESTRICT c2, const size_t count, + const size_t worker) { + PROFILER_ZONE2(worker, "Gen.Activation"); namespace hn = hwy::HWY_NAMESPACE; using DF = hn::ScalableTag; using VF = hn::Vec; @@ -62,29 +64,33 @@ HWY_NOINLINE void Activation(ActivationType activation, T* HWY_RESTRICT c1, // No C2 multiplier. template -void ActivationBatched(ActivationType activation, Mat& c1) { +void ActivationBatched(ActivationType activation, Mat& c1, NestedPools& pools) { using T = typename Mat::T; - for (size_t i = 0; i < c1.Rows(); ++i) { - // Cast to correct type so type deduction works. - Activation(activation, c1.Row(i), static_cast(nullptr), - c1.Cols()); - } + ParallelFor(c1.Rows(), pools, /*pkg_idx=*/0, + [&](uint64_t task, size_t worker) { + // Cast to correct type so type deduction works. + Activation(activation, c1.Row(task), + static_cast(nullptr), c1.Cols(), worker); + }); } template HWY_NOINLINE void ActivationBatched(ActivationType activation, Mat& c1, - const Mat* c2) { + const Mat* c2, NestedPools& pools) { using T = typename Mat::T; HWY_DASSERT(c1.SameShape(*c2)); if (c2 && c2->HasPtr()) { - for (size_t i = 0; i < c1.Rows(); ++i) { - Activation(activation, c1.Row(i), c2->Row(i), c1.Cols()); - } + ParallelFor(c1.Rows(), pools, /*pkg_idx=*/0, + [&](uint64_t task, size_t worker) { + Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(), + worker); + }); } else { // No multiplier - for (size_t i = 0; i < c1.Rows(); ++i) { - Activation(activation, c1.Row(i), static_cast(nullptr), - c1.Cols()); - } + ParallelFor(c1.Rows(), pools, /*pkg_idx=*/0, + [&](uint64_t task, size_t worker) { + Activation(activation, c1.Row(task), + static_cast(nullptr), c1.Cols(), worker); + }); } } @@ -126,7 +132,8 @@ static inline void FFWNoVit(const LayerWeightsPtrs& layer, activations.C2); // Activation (Gelu) and maybe multiply by gate. Store activations in act. 
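// Illustration, not part of this change: with a gate (C2), the batched call
// below computes, per row r and column i (assuming ActivationType::Gelu):
//   C1[r][i] = gelu(C1[r][i]) * C2[r][i]
// and ParallelFor now distributes rows over the workers of env.ctx.pools
// instead of iterating them serially.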
- ActivationBatched(layer_config.activation, activations.C1, &activations.C2); + ActivationBatched(layer_config.activation, activations.C1, &activations.C2, + env.ctx.pools); // Hidden layer -> output layer. CallMatMul(activations.C1, layer.linear_w, output_bias, env, diff --git a/gemma/vit.cc b/gemma/vit.cc index 0231a5f..e96c61e 100644 --- a/gemma/vit.cc +++ b/gemma/vit.cc @@ -224,7 +224,7 @@ void FFWVit(const LayerWeightsPtrs& layer, Activations& activations, activations.C1); // Activation (Gelu), store in C1. - ActivationBatched(layer_config.activation, activations.C1); + ActivationBatched(layer_config.activation, activations.C1, env.ctx.pools); // Hidden layer -> output layer. CallMatMul(activations.C1, layer.vit.linear_1_w, output_bias, env, diff --git a/ops/bench_matmul.cc b/ops/bench_matmul.cc index 949f445..3de3b76 100644 --- a/ops/bench_matmul.cc +++ b/ops/bench_matmul.cc @@ -108,6 +108,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { // BindB/C if there is a single package: they will be a no-op. BindB(b_trans, sizeof(TC), env.parallel); BindC(C, env.parallel); + C.AllocateAndAttachRowPtrs(env.row_ptrs); Tristate use_spinning = Tristate::kDefault; env.ctx.pools.MaybeStartSpinning(use_spinning); @@ -139,8 +140,8 @@ using SFP = SfpStream; void BenchAllMatMul() { if (first_target == 0) first_target = HWY_TARGET; - // Disable the best-target-only limitation. - // if (HWY_TARGET != first_target) return; + // Comment out to disable the best-target-only limitation. + if (HWY_TARGET != first_target) return; // Skip EMU128 (10x slower than SSE4 for SFP) and older x86. if (HWY_TARGET == HWY_EMU128 || HWY_TARGET == HWY_SSSE3 || diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 4465647..1e59165 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -1320,20 +1320,7 @@ template HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, const float* HWY_RESTRICT add, MatMulEnv& env, MatPtrT& C) { - RowPtrs C_rows(C.GetRowPtrs()); - if (HWY_UNLIKELY(!C.GetRowPtrs())) { - if constexpr (HWY_IS_DEBUG_BUILD) { - fprintf(stderr, - "MatMul perf warning: setting row pointers because " - "%s.AttachRowPtrs() was not called.\n", - C.Name()); - } - HWY_DASSERT(C.HasPtr()); - for (size_t r = 0; r < C.Rows(); ++r) { - env.row_ptrs[0][r] = reinterpret_cast(C.Row(r)); - } - C_rows = RowPtrs(env.row_ptrs[0].get()); - } + RowPtrs C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[2]); const Allocator& allocator = env.ctx.allocator; const size_t M = A.Rows(); diff --git a/ops/matmul.cc b/ops/matmul.cc index 3da5512..de6d52b 100644 --- a/ops/matmul.cc +++ b/ops/matmul.cc @@ -428,7 +428,9 @@ MatMulEnv::MatMulEnv(ThreadingContext& ctx) char cpu100[100]; have_timer_stop = hwy::platform::HaveTimerStop(cpu100); - row_ptrs.push_back(hwy::AllocateAligned(MMStorage::kMaxM)); + row_ptrs.push_back(hwy::AllocateAligned(MMStorage::kMaxM)); // A + row_ptrs.push_back(hwy::AllocateAligned(MMStorage::kMaxN)); // B + row_ptrs.push_back(hwy::AllocateAligned(MMStorage::kMaxM)); // C } void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel) { diff --git a/ops/matmul.h b/ops/matmul.h index 9c20177..fed3bc6 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -682,10 +682,11 @@ struct MatMulEnv { // Storage for arbitrary output rows, see `MatPtr::AllocateAndAttachRowPtrs`. // Most MatMul callers use strided MatPtr, but GemmaAttention::ComputeQKV // writes to differing KV positions per query / output row. 
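// Illustration, not part of this change: callers that know their output
// tensor up front can attach per-tensor row pointers once, as
// ops/bench_matmul.cc now does, so MatMul never hits the debug-build
// "setting row pointers" fallback (A, B, C here are placeholder names):
//   C.AllocateAndAttachRowPtrs(env.row_ptrs);  // per-tensor, asan-friendly
//   CallMatMul(A, B, /*add=*/nullptr, env, C);
// Otherwise MatMul builds temporary pointers in the shared C slot,
// env.row_ptrs[2], via GetOrSetTempRowPtrs() (see util/mat.h below).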
- // The first entry is sufficient for any MatMul, but also potentially - // overwritten by each MatMul. Subsequent entries are precomputed for tensors - // and not overwritten. Per-tensor allocations make it likelier that asan - // detects bugs such as use after free, overrun, and dangling references. + // The first three allocations are sufficient for any A, B, C, respectively, + // but also potentially overwritten by each MatMul. Subsequent entries are + // precomputed for tensors and not overwritten. Per-tensor allocations make + // it likelier that asan detects bugs such as use after free, overrun, and + // dangling references. std::vector> row_ptrs; }; diff --git a/ops/ops-inl.h b/ops/ops-inl.h index fb6d09a..3a8132e 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -191,8 +191,9 @@ namespace detail { // Shared by RMSNorm and RMSNormInplace. template -float RMSNormMul(const VT* HWY_RESTRICT x, size_t size) { - PROFILER_ZONE("ops.RMSNormMul"); +float RMSNormMul(const VT* HWY_RESTRICT x, const size_t size, + const HWY_MAYBE_UNUSED size_t worker) { + PROFILER_ZONE2(worker, "ops.RMSNormMul"); const hn::ScalableTag d; const float l2 = DecompressAndCall(d, MakeSpan(x, size), DotKernelDefault()); @@ -204,18 +205,18 @@ float RMSNormMul(const VT* HWY_RESTRICT x, size_t size) { // `x_ofs` is the offset within `x`, required for NuqStream. template -HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x, - const WT* HWY_RESTRICT weight, - size_t w_ofs, OT* HWY_RESTRICT out, - const size_t size) { - PROFILER_ZONE("ops.RMSNorm"); +HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( + const XT* HWY_RESTRICT x, const WT* HWY_RESTRICT weight, size_t w_ofs, + OT* HWY_RESTRICT out, const size_t size, + const size_t HWY_MAYBE_UNUSED worker = 0) { + PROFILER_ZONE2(worker, "ops.RMSNorm"); namespace hn = hwy::HWY_NAMESPACE; const hn::ScalableTag df; using VF = hn::Vec; const size_t NF = hn::Lanes(df); - const VF mul = hn::Set(df, detail::RMSNormMul(x, size)); + const VF mul = hn::Set(df, detail::RMSNormMul(x, size, worker)); const auto packed_x = MakeSpan(x, size); const auto packed_w = MakeSpan(weight, w_ofs + size); @@ -237,18 +238,17 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x, // Same as RMSNorm, but its HWY_RESTRICT forbids passing the same pointer. template -HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(const WT* HWY_RESTRICT weight, - size_t w_ofs, - XT* HWY_RESTRICT inout, - const size_t size) { - PROFILER_ZONE("ops.RMSNormInplace"); +HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace( + const WT* HWY_RESTRICT weight, size_t w_ofs, XT* HWY_RESTRICT inout, + const size_t size, const HWY_MAYBE_UNUSED size_t worker = 0) { + PROFILER_ZONE2(worker, "ops.RMSNormInplace"); namespace hn = hwy::HWY_NAMESPACE; const hn::ScalableTag df; using VF = hn::Vec; const size_t NF = hn::Lanes(df); - const VF mul = hn::Set(df, detail::RMSNormMul(inout, size)); + const VF mul = hn::Set(df, detail::RMSNormMul(inout, size, worker)); const auto packed_w = MakeSpan(weight, w_ofs + size); const auto packed_x = MakeSpan(inout, size); @@ -407,8 +407,9 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings( // This overload is called if `post_qk == PostQKType::HalfRope`. 
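// Illustration, not part of this change: scalar form of the rotation Rope()
// applies, assuming the usual split-half RoPE convention (element i paired
// with element i + dim_qkv / 2); for HalfRope the caller passes dim_qkv / 2
// so only the first half of q or k is rotated:
//   for (size_t i = 0; i < dim_qkv / 2; ++i) {
//     const float theta = pos * inv_timescale[i];
//     const float x0 = x[i], x1 = x[i + dim_qkv / 2];
//     x[i]               = x0 * cosf(theta) - x1 * sinf(theta);
//     x[i + dim_qkv / 2] = x0 * sinf(theta) + x1 * cosf(theta);
//   }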
static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope( float* HWY_RESTRICT x, const size_t dim_qkv, - const float* HWY_RESTRICT inv_timescale, const int pos) { - PROFILER_ZONE("ops.Rope"); + const float* HWY_RESTRICT inv_timescale, const int pos, + const size_t HWY_MAYBE_UNUSED worker = 0) { + PROFILER_ZONE2(worker, "ops.Rope"); HWY_DASSERT(dim_qkv % 2 == 0); const size_t half_dim_qkv = dim_qkv / 2; @@ -465,8 +466,9 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope( // `inv_timescale[dim_qkv / 2]` is precomputed in AttentionActivations. static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy( const float mul, float* HWY_RESTRICT x, const size_t dim_qkv, - const float* HWY_RESTRICT inv_timescale, const int pos) { - PROFILER_ZONE("ops.RopeAndMulBy"); + const float* HWY_RESTRICT inv_timescale, const int pos, + const size_t HWY_MAYBE_UNUSED worker = 0) { + PROFILER_ZONE2(worker, "ops.RopeAndMulBy"); HWY_DASSERT(dim_qkv % 2 == 0); const size_t half_dim_qkv = dim_qkv / 2; @@ -523,10 +525,10 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy( } template -static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(const XT* HWY_RESTRICT x, - float* HWY_RESTRICT out, - const size_t size) { - PROFILER_ZONE("ops.AddFrom"); +static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom( + const XT* HWY_RESTRICT x, float* HWY_RESTRICT out, const size_t size, + const HWY_MAYBE_UNUSED size_t worker = 0) { + PROFILER_ZONE2(worker, "ops.AddFrom"); namespace hn = hwy::HWY_NAMESPACE; const hn::ScalableTag df; @@ -621,9 +623,10 @@ static HWY_INLINE void AddFromBatched(const MatPtrT& x, } template -HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(float c, XT* HWY_RESTRICT x, - size_t size) { - PROFILER_ZONE("ops.MulByConst"); +HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst( + const float c, XT* HWY_RESTRICT x, const size_t size, + const HWY_MAYBE_UNUSED size_t worker = 0) { + PROFILER_ZONE2(worker, "ops.MulByConst"); namespace hn = hwy::HWY_NAMESPACE; const hn::ScalableTag df; const size_t NF = hn::Lanes(df); @@ -659,12 +662,54 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(float c, XT* HWY_RESTRICT x, } } +// Same as above, but without a separate output. Same as below without the add. template -HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(float c, - const XT* HWY_RESTRICT x, - OT* HWY_RESTRICT out, - size_t size) { - PROFILER_ZONE("ops.MulByConstAndAdd"); +HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo( + const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out, + const size_t size, const HWY_MAYBE_UNUSED size_t worker = 0) { + PROFILER_ZONE2(worker, "ops.MulByConstTo"); + namespace hn = hwy::HWY_NAMESPACE; + const hn::ScalableTag df; + const size_t NF = hn::Lanes(df); + using VF = hn::Vec; + + const VF v_c = hn::Set(df, c); + const auto packed_x = MakeSpan(x, size); + const auto packed_out = MakeSpan(out, size); + + size_t i = 0; + if (size >= 2 * NF) { + for (; i <= size - 2 * NF; i += 2 * NF) { + VF x0, x1; + Decompress2(df, packed_x, i, x0, x1); + const VF out0 = hn::Mul(x0, v_c); + const VF out1 = hn::Mul(x1, v_c); + Compress2(df, out0, out1, packed_out, i); + } + } + + const size_t remaining = size - i; + HWY_DASSERT(remaining < 2 * NF); + if (HWY_UNLIKELY(remaining != 0)) { + HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)]; + HWY_ALIGN float buf_out[2 * hn::MaxLanes(df)]; + // Ensure the second vector is zeroed even if remaining <= NF. 
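// Illustration, not part of this change: scalar reference for MulByConstTo,
//   for (size_t i = 0; i < size; ++i) out[i] = c * x[i];
// The tail below decompresses the last `remaining` (< 2 * NF) elements into
// an aligned zero-padded buffer, multiplies two full vectors, and copies
// only `remaining` results back, so `out` is never written past the end;
// the second input vector is zeroed first so the multiply never reads
// uninitialized lanes when remaining <= NF.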
+ hn::Store(hn::Zero(df), df, buf_x + NF); + DecompressAndZeroPad(df, packed_x, i, buf_x, remaining); + const VF x0 = hn::Load(df, buf_x); + const VF x1 = hn::Load(df, buf_x + NF); + const VF out0 = hn::Mul(x0, v_c); + const VF out1 = hn::Mul(x1, v_c); + Compress2(df, out0, out1, MakeSpan(buf_out, 2 * NF), 0); + hwy::CopyBytes(buf_out, out + i, remaining * sizeof(OT)); + } +} + +template +HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd( + const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out, + const size_t size, const HWY_MAYBE_UNUSED size_t worker = 0) { + PROFILER_ZONE2(worker, "ops.MulByConstAndAdd"); namespace hn = hwy::HWY_NAMESPACE; const hn::ScalableTag df; const size_t NF = hn::Lanes(df); @@ -709,8 +754,9 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(float c, // See below for a specialized version for top-1 sampling. static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size, - float temperature = 1.0f) { - PROFILER_ZONE("ops.Softmax"); + float temperature = 1.0f, + const HWY_MAYBE_UNUSED size_t worker = 0) { + PROFILER_ZONE2(worker, "ops.Softmax"); HWY_DASSERT(size != 0); namespace hn = hwy::HWY_NAMESPACE; @@ -840,11 +886,10 @@ static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(float* HWY_RESTRICT x, return TokenAndProb{.token = argmax.token, .prob = prob}; } -static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x, - const size_t size, - const size_t max_pos) { - PROFILER_ZONE("ops.LogitsSoftCap"); - HWY_DASSERT(max_pos <= size); +static HWY_NOINLINE void LogitsSoftCap( + const float cap, float* HWY_RESTRICT x, const size_t size, + const HWY_MAYBE_UNUSED size_t worker = 0) { + PROFILER_ZONE2(worker, "ops.LogitsSoftCap"); namespace hn = hwy::HWY_NAMESPACE; using D = hn::ScalableTag; @@ -852,22 +897,18 @@ static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x, const float inv_cap = 1.0f / cap; - hn::Transform(D(), x, max_pos, [cap, inv_cap](D d, V v) HWY_ATTR { + hn::Transform(D(), x, size, [cap, inv_cap](D d, V v) HWY_ATTR { return hn::Mul(hn::Set(d, cap), hn::Tanh(d, hn::Mul(v, hn::Set(d, inv_cap)))); }); } -static HWY_INLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x, - const size_t size) { - LogitsSoftCap(cap, x, size, size); -} - // Calls LogitsSoftCap if cap != 0.0f. static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCap( - const float cap, float* HWY_RESTRICT x, const size_t size) { + const float cap, float* HWY_RESTRICT x, const size_t size, + const size_t worker = 0) { if (cap != 0.0f) { - LogitsSoftCap(cap, x, size, size); + LogitsSoftCap(cap, x, size, worker); } } diff --git a/util/mat.h b/util/mat.h index 5fe07b6..786351d 100644 --- a/util/mat.h +++ b/util/mat.h @@ -34,7 +34,7 @@ namespace gcpp { // Type-safe wrapper over type-erased uint8_t row pointers from MatPtr. Used -// for C, in future also for A. +// for C (KV output), in future also for A or even B. 
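// Illustration, not part of this change: GetOrSetTempRowPtrs() below is the
// shared fallback that ops/matmul-inl.h now uses for C. If the tensor
// already has row pointers attached (AllocateAndAttachRowPtrs), they are
// wrapped and returned; otherwise one pointer per row is written into the
// caller-provided temporary storage (e.g. env.row_ptrs[2] for C) and a perf
// warning is printed in debug builds.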
template class RowPtrs { public: @@ -317,6 +317,25 @@ class MatPtrT : public MatPtr { } }; +template +RowPtrs GetOrSetTempRowPtrs( + const MatPtrT& mat, + const hwy::AlignedFreeUniquePtr& storage) { + if (HWY_LIKELY(mat.GetRowPtrs())) return RowPtrs(mat.GetRowPtrs()); + + if constexpr (HWY_IS_DEBUG_BUILD) { + fprintf(stderr, + "MatMul perf warning: setting row pointers because " + "%s.AttachRowPtrs() was not called.\n", + mat.Name()); + } + HWY_DASSERT(mat.HasPtr()); + for (size_t r = 0; r < mat.Rows(); ++r) { + storage[r] = reinterpret_cast(const_cast(mat.Row(r))); + } + return RowPtrs(storage.get()); +} + // Calls `func` with `MatPtrT*` plus the optional `args`. This supports all // types used as weights. template diff --git a/util/threading.h b/util/threading.h index 5e13dae..efb536f 100644 --- a/util/threading.h +++ b/util/threading.h @@ -321,6 +321,40 @@ void ParallelizeTwoRanges(const IndexRangePartition& get1, }); } +// Calls `func(task, worker)` for each task in `[0, num_tasks)`. Parallelizes +// over clusters of ONE package, then within each cluster. +template +void ParallelFor(size_t num_tasks, NestedPools& pools, size_t pkg_idx, + const Func& func) { + const size_t pkg_base = pkg_idx * pools.MaxWorkersPerPackage(); + + // If few tasks, run on a single cluster. Also avoids a bit of overhead if + // there is only one cluster. + hwy::ThreadPool& all_clusters = pools.AllClusters(pkg_idx); + const size_t num_clusters = all_clusters.NumWorkers(); + hwy::ThreadPool& cluster = pools.Cluster(pkg_idx, 0); + if (num_clusters == 1 || num_tasks <= cluster.NumWorkers()) { + return cluster.Run(0, num_tasks, [&](uint64_t task, size_t thread) { + func(task, pkg_base + thread); + }); + } + + // Assign each cluster a sub-range. + const IndexRangePartition ranges = + StaticPartition(IndexRange(0, num_tasks), num_clusters, 1); + ParallelizeOneRange( + ranges, all_clusters, + [&](const IndexRange& range, const size_t cluster_idx) { + hwy::ThreadPool& cluster = pools.Cluster(pkg_idx, cluster_idx); + const size_t cluster_base = + pkg_base + cluster_idx * pools.MaxWorkersPerCluster(); + cluster.Run(range.begin(), range.end(), + [&](uint64_t task, size_t thread) { + func(task, cluster_base + thread); + }); + }); +} + } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_
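// Illustration, not part of this change: minimal usage of the new
// ParallelFor. `ExampleParallelFor`, `out` and `num_tasks` are placeholder
// names; `worker` is a flat index (package base + cluster base + thread), so
// it can key PROFILER_ZONE2 (assumes hwy/profiler.h is included) or index
// per-worker scratch buffers.
void ExampleParallelFor(gcpp::NestedPools& pools, float* out,
                        size_t num_tasks) {
  gcpp::ParallelFor(num_tasks, pools, /*pkg_idx=*/0,
                    [&](uint64_t task, size_t worker) {
                      PROFILER_ZONE2(worker, "example.ParallelFor");
                      out[task] = static_cast<float>(worker);  // placeholder
                    });
}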