From 343482c7efb9b562c19d42a71eb3457385c50de3 Mon Sep 17 00:00:00 2001
From: Jan Wassenberg
Date: Tue, 17 Jun 2025 23:21:24 -0700
Subject: [PATCH] 1.02x batch decode speedup: BF16 KV cache

ops-inl.h: Vectorize Rope(), template MulByConst.
Remove unused MulBy, and extra-arg overloads of MulByConst and Softmax.

Fix for DecompressAndZeroPad: ensure second vector filled.

PiperOrigin-RevId: 772779163
---
 BUILD.bazel                |   2 +
 compression/compress-inl.h |  15 ++-
 gemma/attention.cc         |  47 ++++---
 gemma/attention.h          |  38 +++---
 gemma/configs.h            |   1 +
 gemma/kv_cache.h           |   5 +-
 gemma/model_store.cc       |   7 +-
 ops/ops-inl.h              | 254 +++++++++++++++++++++----------------
 8 files changed, 214 insertions(+), 155 deletions(-)

diff --git a/BUILD.bazel b/BUILD.bazel
index 64c68f2..08fcd74 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -449,6 +449,7 @@ cc_library(
     srcs = ["gemma/kv_cache.cc"],
     hdrs = ["gemma/kv_cache.h"],
     deps = [
+        ":basics",
         ":configs",
         ":gemma_args",
         ":mat",
@@ -504,6 +505,7 @@ cc_library(
         ":threading",
         ":threading_context",
         ":weights",
+        "//compression:compress",
        "//compression:types",
         "//io:blob_store",
         "//io",
diff --git a/compression/compress-inl.h b/compression/compress-inl.h
index c5d5129..512f8fa 100644
--- a/compression/compress-inl.h
+++ b/compression/compress-inl.h
@@ -517,8 +517,19 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
   }
 }
 
-// Stores two f32 vectors to f32 or bf16; avoids duplicating RMSNorm and
-// RMSNormInplace for the two output types.
+// Same as above, but without parallelization or benchmarking.
+template <typename Packed>
+HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
+                           CompressPerThread& tls,
+                           const PackedSpan<Packed>& packed,
+                           const size_t packed_ofs) {
+  packed.BoundsCheck(packed_ofs, num);
+  using Traits = CompressTraits<hwy::RemoveCvRef<Packed>>;
+  const hn::ScalableTag<float> df;
+  Traits::Compress(df, raw, num, tls, packed, packed_ofs);
+}
+
+// Stores two f32 vectors to f32 or bf16.
 template <class DF, typename Packed, class VF = hn::Vec<DF>>
 void Compress2(DF df, VF raw0, VF raw1, const PackedSpan<Packed>& packed,
                const size_t packed_ofs) {
diff --git a/gemma/attention.cc b/gemma/attention.cc
index ceb80ad..84e41b2 100644
--- a/gemma/attention.cc
+++ b/gemma/attention.cc
@@ -24,6 +24,7 @@
 #endif  // HWY_DISABLED_TARGETS
 
 #include "gemma/activations.h"
+#include "gemma/configs.h"  // kMaxQKVDim
 #include "gemma/gemma.h"
 #include "gemma/weights.h"
 #include "util/threading.h"
@@ -39,6 +40,7 @@
 #include "hwy/foreach_target.h"  // IWYU pragma: keep
 #include "hwy/highway.h"
 // After highway.h
+#include "compression/compress-inl.h"
 #include "ops/ops-inl.h"
 
 HWY_BEFORE_NAMESPACE();
@@ -50,7 +52,7 @@ namespace HWY_NAMESPACE {
 static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
                              const hwy::Divisor& div_seq_len,
                              const float* HWY_RESTRICT q,
-                             const MatPtrT<float>& k, float* HWY_RESTRICT att) {
+                             const MatPtrT<BF16>& k, float* HWY_RESTRICT att) {
   if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
     // Slightly faster: no wraparound.
     for (size_t pos = start_pos; pos <= last_pos; ++pos) {
@@ -66,8 +68,7 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
   }
 }
 
-template <typename U>
-static void PositionalEncodingQK(U* qk, const size_t layer_idx,
+static void PositionalEncodingQK(float* qk, const size_t layer_idx,
                                  const LayerWeightsPtrs& layer,
                                  const AttentionActivations& activations,
                                  const size_t pos, const float mul = 1.0f) {
@@ -97,7 +98,7 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos,
                                     const size_t last_pos,
                                     const hwy::Divisor& div_seq_len,
                                     const float* HWY_RESTRICT att,
-                                    const MatPtrT<float>& v,
+                                    const MatPtrT<BF16>& v,
                                     float* HWY_RESTRICT att_out) {
   const size_t qkv_dim = v.Cols();
   hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
@@ -110,7 +111,7 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos,
   } else {
     for (size_t pos = start_pos; pos <= last_pos; ++pos) {
       const size_t pos_modulo = div_seq_len.Remainder(pos);
-      const float* HWY_RESTRICT v_ptr = v.Row(pos_modulo);
+      const BF16* HWY_RESTRICT v_ptr = v.Row(pos_modulo);
       MulByConstAndAdd(att[pos_modulo], v_ptr, att_out, v.Cols());
     }
   }
@@ -118,12 +119,14 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos,
 
 // Calculates the attention outputs for a single q, which may be updated
 // in place for RMSNorm.
-void SingleDotSoftmaxWeightedSum(
-    const size_t pos, const size_t start_pos, const size_t last_pos,
-    float* HWY_RESTRICT q, const MatPtrT<float>& k, const MatPtrT<float>& v,
-    const size_t layer_idx, const LayerWeightsPtrs& layer,
-    const AttentionActivations& activations, float* HWY_RESTRICT att,
-    float* HWY_RESTRICT att_out) {
+void SingleDotSoftmaxWeightedSum(const size_t pos, const size_t start_pos,
+                                 const size_t last_pos, float* HWY_RESTRICT q,
+                                 const MatPtrT<BF16>& k, const MatPtrT<BF16>& v,
+                                 const size_t layer_idx,
+                                 const LayerWeightsPtrs& layer,
+                                 const AttentionActivations& activations,
+                                 float* HWY_RESTRICT att,
+                                 float* HWY_RESTRICT att_out) {
   const float att_cap = activations.config.att_cap;
   const float query_scale = activations.query_scale;
   const size_t seq_len =
@@ -136,6 +139,7 @@ void SingleDotSoftmaxWeightedSum(
                   layer.layer_config.qkv_dim);
     });
   }
+
   PositionalEncodingQK(q, layer_idx, layer, activations, pos, query_scale);
 
   QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att);
@@ -220,10 +224,10 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
           // this query and head.
           const size_t kv_head_offset = layer_idx * cache_layer_size +
                                         head_offset;
-          MatPtrT<float> k("k_view", Extents2D(seq_len, qkv_dim));
+          MatPtrT<BF16> k("k_view", Extents2D(seq_len, qkv_dim));
           k.SetPtr(kv_cache.Row(0) + kv_head_offset, kv_cache.Stride());
-          MatPtrT<float> v("v_view", Extents2D(seq_len, qkv_dim));
+          MatPtrT<BF16> v("v_view", Extents2D(seq_len, qkv_dim));
           v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim,
                    kv_cache.Stride());
@@ -263,8 +267,8 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
   // Set up MatMul row pointers for writing to KV, which consists of
   // `kv_heads` pairs of (k, v) vectors. This safely handles wraparound
   // because rows are computed modulo seq_len.
- MatPtrT kv_rows("kv", Extents2D(activations.pre_att_rms_out.Rows(), - layer.qkv_einsum_w2.Rows())); + MatPtrT kv_rows("kv", Extents2D(activations.pre_att_rms_out.Rows(), + layer.qkv_einsum_w2.Rows())); for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved; ++interleaved_idx) { const size_t qi = div_qbatch.Remainder(interleaved_idx); @@ -291,9 +295,9 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, const size_t pos = qbatch.Pos(qi) + batch_idx; const size_t cache_pos = activations.div_seq_len.Remainder(pos); auto& kv_cache = qbatch.KV(qi).kv_cache; - float* HWY_RESTRICT kv = kv_cache.Row(cache_pos) + - layer_idx * cache_layer_size + - head * qkv_dim * 2; + BF16* HWY_RESTRICT kv = kv_cache.Row(cache_pos) + + layer_idx * cache_layer_size + + head * qkv_dim * 2; // Apply further processing to K. if (layer.key_norm_scale.HasPtr()) { @@ -302,7 +306,12 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, }); } - PositionalEncodingQK(kv, layer_idx, layer, activations, pos); + HWY_ALIGN float kv_f32[kMaxQKVDim]; + const hn::ScalableTag df; + DecompressAndZeroPad(df, MakeSpan(kv, qkv_dim), 0, kv_f32, qkv_dim); + PositionalEncodingQK(kv_f32, layer_idx, layer, activations, pos); + CompressPerThread tls; + Compress(kv_f32, qkv_dim, tls, MakeSpan(kv, qkv_dim), 0); }); } diff --git a/gemma/attention.h b/gemma/attention.h index 589cdb1..d00e81d 100644 --- a/gemma/attention.h +++ b/gemma/attention.h @@ -26,25 +26,25 @@ namespace gcpp { // Passed to HWY_VISIT_TARGETS; declares for one target. -#define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \ - namespace NAMESPACE { \ - void SingleDotSoftmaxWeightedSum( \ - const size_t pos, const size_t start_pos, const size_t last_pos, \ - float* HWY_RESTRICT q, const MatPtrT& k, const MatPtrT& v, \ - size_t layer_idx, const LayerWeightsPtrs& layer, \ - const AttentionActivations& activations, float* HWY_RESTRICT att, \ - float* HWY_RESTRICT att_out); \ - \ - void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \ - const LayerWeightsPtrs& layer, \ - AttentionActivations& activations, \ - QBatch& qbatch, NestedPools& pools); \ - \ - void GemmaAttention(size_t num_tokens, const size_t layer_idx, \ - const LayerWeightsPtrs& layer, \ - AttentionActivations& activations, QBatch& qbatch, \ - MatMulEnv& env, int flags); \ - /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ +#define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \ + namespace NAMESPACE { \ + void SingleDotSoftmaxWeightedSum( \ + const size_t pos, const size_t start_pos, const size_t last_pos, \ + float* HWY_RESTRICT q, const MatPtrT& k, const MatPtrT& v, \ + size_t layer_idx, const LayerWeightsPtrs& layer, \ + const AttentionActivations& activations, float* HWY_RESTRICT att, \ + float* HWY_RESTRICT att_out); \ + \ + void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \ + const LayerWeightsPtrs& layer, \ + AttentionActivations& activations, \ + QBatch& qbatch, NestedPools& pools); \ + \ + void GemmaAttention(size_t num_tokens, const size_t layer_idx, \ + const LayerWeightsPtrs& layer, \ + AttentionActivations& activations, QBatch& qbatch, \ + MatMulEnv& env, int flags); \ + /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ } // namespace NAMESPACE // Function declarations for each SIMD target. 
diff --git a/gemma/configs.h b/gemma/configs.h
index ffc8894..720a788 100644
--- a/gemma/configs.h
+++ b/gemma/configs.h
@@ -32,6 +32,7 @@ namespace gcpp {
 
 static constexpr size_t kMaxConv1DWidth = 4;
+static constexpr size_t kMaxQKVDim = 1024;
 
 // Instruction-tuned models require extra 'turn structure' tokens in prompts.
 enum class PromptWrapping {
diff --git a/gemma/kv_cache.h b/gemma/kv_cache.h
index 8c8d762..aea5120 100644
--- a/gemma/kv_cache.h
+++ b/gemma/kv_cache.h
@@ -19,7 +19,8 @@
 #include <stddef.h>
 
 #include "gemma/configs.h"     // ModelConfig
-#include "gemma/gemma_args.h"
+#include "gemma/gemma_args.h"  // InferenceArgs
+#include "util/basics.h"       // BF16
 #include "util/mat.h"
 
 namespace gcpp {
@@ -41,7 +42,7 @@ struct KVCache {
   MatStorageT<float> conv1d_cache;
   MatStorageT<float> rglru_cache;  // [griffin_layers, model_dim]
 
-  MatStorageT<float> kv_cache;  // [seq_len, layers * kv_heads * qkv_dim * 2]
+  MatStorageT<BF16> kv_cache;   // [seq_len, layers * kv_heads * qkv_dim * 2]
 
  private:
   // For use by other ctor and Copy()
diff --git a/gemma/model_store.cc b/gemma/model_store.cc
index e17a9fa..d3916bd 100644
--- a/gemma/model_store.cc
+++ b/gemma/model_store.cc
@@ -25,7 +25,7 @@
 #include
 
 #include "compression/types.h"
-#include "gemma/configs.h"  // ModelConfig
+#include "gemma/configs.h"  // ModelConfig, kMaxQKVDim
 #include "gemma/tensor_info.h"
 #include "gemma/tokenizer.h"
 #include "io/blob_store.h"
@@ -234,6 +234,11 @@ static ModelConfig ReadOrDeduceConfig(BlobReader& reader,
   HWY_ASSERT(config.model != Model::UNKNOWN);
   HWY_ASSERT(config.wrapping != PromptWrapping::kSentinel);
   HWY_ASSERT(config.weight != Type::kUnknown);
+  for (const LayerConfig& layer_config : config.layer_configs) {
+    if (static_cast<size_t>(layer_config.qkv_dim) > kMaxQKVDim) {
+      HWY_ABORT("Increase kMaxQKVDim to at least %u.", layer_config.qkv_dim);
+    }
+  }
 
   // We trust the deserialized config, but checking helps to validate the
   // deduction, which we rely on below for pre-2025 files.
diff --git a/ops/ops-inl.h b/ops/ops-inl.h
index c7e5758..6dc1b46 100644
--- a/ops/ops-inl.h
+++ b/ops/ops-inl.h
@@ -341,6 +341,10 @@ HWY_NOINLINE void LayerNorm(const XT* x, const WT* HWY_RESTRICT scale,
     HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)];
     HWY_ALIGN float buf_scale[2 * hn::MaxLanes(df)];
     HWY_ALIGN float buf_bias[2 * hn::MaxLanes(df)];
+    // Ensure the second vectors are zeroed even if remaining <= NF.
+    hn::Store(hn::Zero(df), df, buf_x + NF);
+    hn::Store(hn::Zero(df), df, buf_scale + NF);
+    hn::Store(hn::Zero(df), df, buf_bias + NF);
     HWY_ALIGN OT buf_out[2 * hn::MaxLanes(df)];
     DecompressAndZeroPad(df, packed_x, i, buf_x, remaining);
     DecompressAndZeroPad(df, packed_scale, i, buf_scale, remaining);
@@ -399,94 +403,122 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
    of this rotation matrix which is simply the same matrix with -pos parameter)
 */
 
-// `inv_timescale[dim_qkv / 2]` is precomputed in Activations::Allocate.
-// This overload is called if kUseHalfRope.
+// `inv_timescale[dim_qkv / 2]` is precomputed in AttentionActivations.
+// This overload is called if `post_qk == PostQKType::HalfRope`.
 static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
     float* HWY_RESTRICT x, size_t dim_qkv,
-    const float* HWY_RESTRICT inv_timescale, int pos) {
+    const float* HWY_RESTRICT inv_timescale, const int pos) {
   PROFILER_ZONE("ops.Rope");
   HWY_DASSERT(dim_qkv % 2 == 0);
   const size_t half_dim_qkv = dim_qkv / 2;
-  for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
-    const float theta = StaticCast<float>(pos) * inv_timescale[dim];
-    const float cos_val = cosf(theta);
-    const float sin_val = sinf(theta);
-    const float x0 = x[dim];
-    const float x1 = x[dim + half_dim_qkv];
-    x[dim] = x0 * cos_val - x1 * sin_val;
-    x[dim + half_dim_qkv] = x0 * sin_val + x1 * cos_val;
-  }
-}
-// `inv_timescale[dim_qkv / 2]` is precomputed in Activations::Allocate.
-static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
-    const float mul, float* HWY_RESTRICT x, size_t dim_qkv,
-    const float* HWY_RESTRICT inv_timescale, int pos) {
-  PROFILER_ZONE("ops.RopeAndMulBy");
-  HWY_DASSERT(dim_qkv % 2 == 0);
-  const size_t half_dim_qkv = dim_qkv / 2;
-
-  using D = hn::ScalableTag<float>;
-  using V = hn::Vec<D>;
-  const D d;
+
+  const hn::ScalableTag<float> df;
+  const size_t NF = hn::Lanes(df);
+  using VF = hn::Vec<decltype(df)>;
+  const VF vpos = hn::Set(df, static_cast<float>(pos));
 
   // Vectorize computation for half_dim_qkv - (half_dim_qkv % Lanes)
-  const size_t vectorizable_dims = hwy::RoundDownTo(half_dim_qkv, hn::Lanes(d));
+  const size_t vectorizable_dims = hwy::RoundDownTo(half_dim_qkv, NF);
   size_t dim = 0;
-  for (; dim < vectorizable_dims; dim += hn::Lanes(d)) {
-    // Compute thetas
-    V pos_vec = hn::Set(d, pos);
-    V inv_time_scale_vec = hn::LoadU(d, inv_timescale + dim);
-    V theta_vec = hn::Mul(pos_vec, inv_time_scale_vec);
+  for (; dim < vectorizable_dims; dim += NF) {
+    const VF vinv_time_scale = hn::LoadU(df, inv_timescale + dim);
+    const VF vtheta = hn::Mul(vpos, vinv_time_scale);
 
     // Compute rotations.
-    V cos_theta_vec;
-    V sin_theta_vec;
-    hn::SinCos(d, theta_vec, sin_theta_vec, cos_theta_vec);
+    VF vcos_theta;
+    VF vsin_theta;
+    hn::SinCos(df, vtheta, vsin_theta, vcos_theta);
 
-    // Scale input with rotations and multiply with constant.
-    V mul_vec = hn::Set(d, mul);
-    V x0_vec = hn::Mul(mul_vec, hn::LoadU(d, x + dim));
-    V x1_vec = hn::Mul(mul_vec, hn::LoadU(d, x + dim + half_dim_qkv));
+    // Scale input with rotations.
+    const VF vx0 = hn::LoadU(df, x + dim);
+    const VF vx1 = hn::LoadU(df, x + dim + half_dim_qkv);
+    const VF vout0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
+    const VF vout1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
 
-    V xout_0_vec = hn::MulSub(x0_vec, cos_theta_vec,
-                              hn::Mul(x1_vec, sin_theta_vec));
-    V xout_1_vec = hn::MulAdd(x0_vec, sin_theta_vec,
-                              hn::Mul(x1_vec, cos_theta_vec));
-
-    // Store
-    hn::StoreU(xout_0_vec, d, x + dim);
-    hn::StoreU(xout_1_vec, d, x + dim + half_dim_qkv);
+    hn::StoreU(vout0, df, x + dim);
+    hn::StoreU(vout1, df, x + dim + half_dim_qkv);
   }
 
   // Vectorize computation for remaining dims - same as above, but with LoadN.
   const size_t remaining_dims = half_dim_qkv - dim;
-  HWY_DASSERT(remaining_dims < hn::Lanes(d));  // at most one iteration
+  HWY_DASSERT(remaining_dims < NF);  // at most one iteration
   if (remaining_dims != 0) {
-    // Compute thetas
-    V pos_vec = hn::Set(d, pos);
-    V inv_time_scale_vec = hn::LoadN(d, inv_timescale + dim, remaining_dims);
-    V theta_vec = hn::Mul(pos_vec, inv_time_scale_vec);
+    VF vinv_time_scale = hn::LoadN(df, inv_timescale + dim, remaining_dims);
+    VF vtheta = hn::Mul(vpos, vinv_time_scale);
 
     // Compute rotations.
-    V cos_theta_vec;
-    V sin_theta_vec;
-    hn::SinCos(d, theta_vec, sin_theta_vec, cos_theta_vec);
+    VF vcos_theta;
+    VF vsin_theta;
+    hn::SinCos(df, vtheta, vsin_theta, vcos_theta);
+
+    // Scale input with rotations.
+    const VF vx0 = hn::LoadN(df, x + dim, remaining_dims);
+    const VF vx1 = hn::LoadN(df, x + dim + half_dim_qkv, remaining_dims);
+    const VF vout0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
+    const VF vout1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
+
+    hn::StoreN(vout0, df, x + dim, remaining_dims);
+    hn::StoreN(vout1, df, x + dim + half_dim_qkv, remaining_dims);
+  }
+}
+
+// `inv_timescale[dim_qkv / 2]` is precomputed in AttentionActivations.
+static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
+    const float mul, float* HWY_RESTRICT x, size_t dim_qkv,
+    const float* HWY_RESTRICT inv_timescale, const int pos) {
+  PROFILER_ZONE("ops.RopeAndMulBy");
+  HWY_DASSERT(dim_qkv % 2 == 0);
+  const size_t half_dim_qkv = dim_qkv / 2;
+
+  const hn::ScalableTag<float> df;
+  const size_t NF = hn::Lanes(df);
+  using VF = hn::Vec<decltype(df)>;
+  const VF vmul = hn::Set(df, mul);
+  const VF vpos = hn::Set(df, static_cast<float>(pos));
+
+  // Vectorize computation for half_dim_qkv - (half_dim_qkv % Lanes)
+  const size_t vectorizable_dims = hwy::RoundDownTo(half_dim_qkv, NF);
+  size_t dim = 0;
+  for (; dim < vectorizable_dims; dim += NF) {
+    const VF vinv_time_scale = hn::LoadU(df, inv_timescale + dim);
+    const VF vtheta = hn::Mul(vpos, vinv_time_scale);
+
+    // Compute rotations.
+    VF vcos_theta;
+    VF vsin_theta;
+    hn::SinCos(df, vtheta, vsin_theta, vcos_theta);
+
+    // Scale input with rotations and multiply with constant.
+    const VF vx0 = hn::Mul(vmul, hn::LoadU(df, x + dim));
+    const VF vx1 = hn::Mul(vmul, hn::LoadU(df, x + dim + half_dim_qkv));
+    const VF vout0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
+    const VF vout1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
+
+    hn::StoreU(vout0, df, x + dim);
+    hn::StoreU(vout1, df, x + dim + half_dim_qkv);
+  }
+
+  // Vectorize computation for remaining dims - same as above, but with LoadN.
+  const size_t remaining_dims = half_dim_qkv - dim;
+  HWY_DASSERT(remaining_dims < NF);  // at most one iteration
+  if (remaining_dims != 0) {
+    VF vinv_time_scale = hn::LoadN(df, inv_timescale + dim, remaining_dims);
+    VF vtheta = hn::Mul(vpos, vinv_time_scale);
+
+    // Compute rotations.
+    VF vcos_theta;
+    VF vsin_theta;
+    hn::SinCos(df, vtheta, vsin_theta, vcos_theta);
+
+    // Scale input with rotations and multiply with constant.
+    const VF vx0 = hn::Mul(vmul, hn::LoadN(df, x + dim, remaining_dims));
+    const VF vx1 =
+        hn::Mul(vmul, hn::LoadN(df, x + dim + half_dim_qkv, remaining_dims));
+    const VF vout0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
+    const VF vout1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
+
+    hn::StoreN(vout0, df, x + dim, remaining_dims);
+    hn::StoreN(vout1, df, x + dim + half_dim_qkv, remaining_dims);
+  }
+}
@@ -521,6 +553,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(const XT* HWY_RESTRICT x,
   HWY_DASSERT(remaining1 < NF);
   if (HWY_UNLIKELY(remaining != 0)) {
     HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)];
+    // Ensure the second vector is zeroed even if remaining <= NF.
+    hn::Store(hn::Zero(df), df, buf_x + NF);
     DecompressAndZeroPad(df, packed_x, i, buf_x, remaining);
     const VF x0 = hn::Load(df, buf_x);
     const VF x1 = hn::Load(df, buf_x + NF);
@@ -586,42 +620,43 @@ static HWY_INLINE void AddFromBatched(const MatPtrT<float>& x,
   }
 }
 
-static HWY_NOINLINE void MulBy(const float* HWY_RESTRICT other,
-                               float* HWY_RESTRICT x, const size_t size,
-                               const size_t max_pos) {
-  PROFILER_ZONE("ops.MulBy");
-  HWY_DASSERT(max_pos <= size);
-  namespace hn = hwy::HWY_NAMESPACE;
-  using D = hn::ScalableTag<float>;
-  using V = hn::Vec<D>;
-
-  hn::Transform1(D(), x, max_pos, other,
-                 [](const auto d, const V x, const V other)
-                     HWY_ATTR { return hn::Mul(x, other); });
-}
-
-static HWY_INLINE HWY_MAYBE_UNUSED void MulBy(const float* HWY_RESTRICT other,
-                                              float* HWY_RESTRICT x,
-                                              const size_t size) {
-  return MulBy(other, x, size, size);
-}
-
-static HWY_NOINLINE void MulByConst(const float c, float* HWY_RESTRICT x,
-                                    const size_t size, const size_t max_pos) {
+template <typename XT>
+HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(float c, XT* HWY_RESTRICT x,
+                                              size_t size) {
   PROFILER_ZONE("ops.MulByConst");
-  HWY_DASSERT(max_pos <= size);
   namespace hn = hwy::HWY_NAMESPACE;
-  using D = hn::ScalableTag<float>;
-  using V = hn::Vec<D>;
-  hn::Transform(D(), x, max_pos, [c](const auto d, const V x) HWY_ATTR {
-    return hn::Mul(x, hn::Set(d, c));
-  });
-}
+  const hn::ScalableTag<float> df;
+  const size_t NF = hn::Lanes(df);
+  using VF = hn::Vec<decltype(df)>;
 
-static HWY_INLINE HWY_MAYBE_UNUSED void MulByConst(const float c,
-                                                   float* HWY_RESTRICT x,
-                                                   const size_t size) {
-  MulByConst(c, x, size, size);
+  const VF v_c = hn::Set(df, c);
+  const auto packed_x = MakeSpan(x, size);
+
+  size_t i = 0;
+  if (size >= 2 * NF) {
+    for (; i <= size - 2 * NF; i += 2 * NF) {
+      VF x0, x1;
+      Decompress2(df, packed_x, i, x0, x1);
+      x0 = hn::Mul(x0, v_c);
+      x1 = hn::Mul(x1, v_c);
+      Compress2(df, x0, x1, packed_x, i);
+    }
+  }
+
+  const size_t remaining = size - i;
+  HWY_DASSERT(remaining < 2 * NF);
+  if (HWY_UNLIKELY(remaining != 0)) {
+    HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)];
+    // Ensure the second vector is zeroed even if remaining <= NF.
+    hn::Store(hn::Zero(df), df, buf_x + NF);
+    DecompressAndZeroPad(df, packed_x, i, buf_x, remaining);
+    VF x0 = hn::Load(df, buf_x);
+    VF x1 = hn::Load(df, buf_x + NF);
+    x0 = hn::Mul(x0, v_c);
+    x1 = hn::Mul(x1, v_c);
+    HWY_ALIGN XT buf_out[2 * hn::MaxLanes(df)];
+    Compress2(df, x0, x1, MakeSpan(buf_out, 2 * NF), 0);
+    hwy::CopyBytes(buf_out, x + i, remaining * sizeof(XT));
+  }
 }
 
 template <typename XT, typename OT>
@@ -656,6 +691,9 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(float c,
   if (HWY_UNLIKELY(remaining != 0)) {
     HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)];
     HWY_ALIGN float buf_out[2 * hn::MaxLanes(df)];
+    // Ensure the second vectors are zeroed even if remaining <= NF.
+    hn::Store(hn::Zero(df), df, buf_x + NF);
+    hn::Store(hn::Zero(df), df, buf_out + NF);
     DecompressAndZeroPad(df, packed_x, i, buf_x, remaining);
     DecompressAndZeroPad(df, packed_out, i, buf_out, remaining);
     const VF x0 = hn::Load(df, buf_x);
@@ -671,11 +709,9 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(float c,
 // See below for a specialized version for top-1 sampling.
 static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
-                                 const size_t mask_pos,
                                  float temperature = 1.0f) {
   PROFILER_ZONE("ops.Softmax");
   HWY_DASSERT(size != 0);
-  HWY_DASSERT(mask_pos <= size);
 
   namespace hn = hwy::HWY_NAMESPACE;
   using D = hn::ScalableTag<float>;
   using V = hn::Vec<D>;
@@ -685,13 +721,13 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
   const V vmin = hn::Set(d, hwy::LowestValue<float>());
   V vmax = vmin;
   V* pmax = &vmax;  // workaround for SVE: cannot capture &vector directly
-  hn::Foreach(d, x, mask_pos, vmin,
-              [pmax](const auto d, const V value)
-                  HWY_ATTR { *pmax = hn::Max(*pmax, value); });
+  hn::Foreach(d, x, size, vmin, [pmax](const auto d, const V value) HWY_ATTR {
+    *pmax = hn::Max(*pmax, value);
+  });
   vmax = hn::MaxOfLanes(d, vmax);
 
   // Subtract max (avoid precision loss for large exponents) and exponentiate.
-  hn::Transform(d, x, mask_pos, [pmax](const auto d, const V value) HWY_ATTR {
+  hn::Transform(d, x, size, [pmax](const auto d, const V value) HWY_ATTR {
     if constexpr (HWY_TARGET & HWY_ALL_SVE) {
       // Temporary workaround for buggy SVE codegen: avoid inlined Exp().
       return hn::CallExp(d, hn::Sub(value, *pmax));
@@ -702,7 +738,7 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
 
   if (temperature != 1.0f) {
     const float temperature_inv = 1.0f / temperature;
-    hn::Transform(d, x, mask_pos,
+    hn::Transform(d, x, size,
                   [temperature_inv](const auto d, const V value) HWY_ATTR {
                     return hn::Mul(value, hn::Set(d, temperature_inv));
                   });
@@ -712,16 +748,10 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
   // not make a huge difference. It halves the standard deviation of the sum of
   // the normalized probabilities from 1E-7 to 5E-8, but actually also changes
   // the generated text after a few hundred tokens.
-  const float sum_exp = Sum(d, x, mask_pos);
+  const float sum_exp = Sum(d, x, size);
   // Double-precision reciprocal does not appear to affect the results.
   const float mul = 1.0f / sum_exp;
-  MulByConst(mul, x, size, mask_pos);
-}
-
-static HWY_INLINE HWY_MAYBE_UNUSED void Softmax(float* HWY_RESTRICT x,
-                                                const size_t size,
-                                                float temperature = 1.0f) {
-  Softmax(x, size, size, temperature);
+  MulByConst(mul, x, size);
 }
 
 // Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max /
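
The ComputeQKV hunk above round-trips each KV row through BF16: DecompressAndZeroPad widens the cached BF16 values to f32, PositionalEncodingQK applies RoPE in f32, and Compress narrows back to BF16. The standalone scalar sketch below illustrates that round-trip under the same rotation formula; it is not part of the patch, and the helpers (F32ToBF16, BF16ToF32, RopeScalar) and the inv_timescale formula are illustrative stand-ins for the Highway and compression routines used above.

// Scalar reference sketch of the BF16 KV round-trip (illustrative only).
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

// Round a float to the nearest bfloat16 (round-to-nearest-even; NaN handling omitted).
static uint16_t F32ToBF16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  const uint32_t rounding = 0x7FFF + ((bits >> 16) & 1);
  return static_cast<uint16_t>((bits + rounding) >> 16);
}

// Widen a bfloat16 back to float by placing it in the upper 16 bits.
static float BF16ToF32(uint16_t h) {
  const uint32_t bits = static_cast<uint32_t>(h) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Scalar RoPE: rotates each pair (x[d], x[d + dim/2]) by pos * inv_timescale[d],
// matching the scalar loop that the patch replaces with a vectorized Rope().
static void RopeScalar(float* x, size_t dim, const float* inv_timescale, int pos) {
  const size_t half = dim / 2;
  for (size_t d = 0; d < half; ++d) {
    const float theta = static_cast<float>(pos) * inv_timescale[d];
    const float c = std::cos(theta), s = std::sin(theta);
    const float x0 = x[d], x1 = x[d + half];
    x[d] = x0 * c - x1 * s;          // uses the original x0/x1 for both outputs
    x[d + half] = x0 * s + x1 * c;
  }
}

int main() {
  const size_t dim = 8;  // stands in for qkv_dim
  std::vector<float> k = {0.1f, -0.2f, 0.3f, -0.4f, 0.5f, -0.6f, 0.7f, -0.8f};
  std::vector<float> inv_timescale(dim / 2);
  for (size_t d = 0; d < dim / 2; ++d) {
    inv_timescale[d] = std::pow(10000.0f, -2.0f * d / dim);
  }

  // "Compress": store the row as BF16, as the patched kv_cache does.
  std::vector<uint16_t> cache(dim);
  for (size_t d = 0; d < dim; ++d) cache[d] = F32ToBF16(k[d]);

  // Decompress to f32, apply RoPE at position 3, then recompress to BF16.
  std::vector<float> f32(dim);
  for (size_t d = 0; d < dim; ++d) f32[d] = BF16ToF32(cache[d]);
  RopeScalar(f32.data(), dim, inv_timescale.data(), /*pos=*/3);
  for (size_t d = 0; d < dim; ++d) cache[d] = F32ToBF16(f32[d]);

  for (size_t d = 0; d < dim; ++d) printf("%f ", BF16ToF32(cache[d]));
  printf("\n");
  return 0;
}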