1.02x batch decode speedup: BF16 KV cache

ops-inl.h: Vectorize Rope(), template
Remove unused MulBy, and extra-arg overloads of MulByConst and Softmax
Fix for DecompressAndZeroPad: ensure second vector filled

PiperOrigin-RevId: 772779163
Jan Wassenberg 2025-06-17 23:21:24 -07:00 committed by Copybara-Service
parent 606e22155a
commit 343482c7ef
8 changed files with 214 additions and 155 deletions


@@ -449,6 +449,7 @@ cc_library(
     srcs = ["gemma/kv_cache.cc"],
     hdrs = ["gemma/kv_cache.h"],
     deps = [
+        ":basics",
         ":configs",
         ":gemma_args",
         ":mat",

@@ -504,6 +505,7 @@ cc_library(
         ":threading",
         ":threading_context",
         ":weights",
+        "//compression:compress",
         "//compression:types",
         "//io:blob_store",
         "//io",


@@ -517,8 +517,19 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
   }
 }
 
-// Stores two f32 vectors to f32 or bf16; avoids duplicating RMSNorm and
-// RMSNormInplace for the two output types.
+// Same as above, but without parallelization nor benchmarking.
+template <typename Packed>
+HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
+                           CompressPerThread& tls,
+                           const PackedSpan<Packed>& packed,
+                           const size_t packed_ofs) {
+  packed.BoundsCheck(packed_ofs, num);
+  using Traits = CompressTraits<hwy::RemoveConst<Packed>>;
+  const hn::ScalableTag<float> df;
+  Traits::Compress(df, raw, num, tls, packed, packed_ofs);
+}
+
+// Stores two f32 vectors to f32 or bf16.
 template <class DF, typename Packed, HWY_IF_F32_D(DF), class VF = hn::Vec<DF>>
 void Compress2(DF df, VF raw0, VF raw1, const PackedSpan<Packed>& packed,
                const size_t packed_ofs) {

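For context, a minimal call-site sketch of the new single-threaded overload (the row buffer and qkv_dim value are hypothetical; Compress, CompressPerThread, MakeSpan, and BF16 are the repo types used elsewhere in this diff):

  // Sketch: store one f32 activation row into a BF16 KV slot, without the
  // thread pool or benchmarking of the parallel Compress overload above.
  HWY_ALIGN float row[256];   // hypothetical qkv_dim = 256, filled elsewhere
  BF16 kv_slot[256];
  CompressPerThread tls;      // per-thread compression scratch
  Compress(row, /*num=*/256, tls, MakeSpan(kv_slot, 256), /*packed_ofs=*/0);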

@@ -24,6 +24,7 @@
 #endif  // HWY_DISABLED_TARGETS
 
 #include "gemma/activations.h"
+#include "gemma/configs.h"  // kMaxQKVDim
 #include "gemma/gemma.h"
 #include "gemma/weights.h"
 #include "util/threading.h"
@@ -39,6 +40,7 @@
 #include "hwy/foreach_target.h"  // IWYU pragma: keep
 #include "hwy/highway.h"
 // After highway.h
+#include "compression/compress-inl.h"
 #include "ops/ops-inl.h"
 
 HWY_BEFORE_NAMESPACE();
@@ -50,7 +52,7 @@ namespace HWY_NAMESPACE {
 static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
                              const hwy::Divisor& div_seq_len,
                              const float* HWY_RESTRICT q,
-                             const MatPtrT<float>& k, float* HWY_RESTRICT att) {
+                             const MatPtrT<BF16>& k, float* HWY_RESTRICT att) {
   if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
     // Slightly faster: no wraparound.
     for (size_t pos = start_pos; pos <= last_pos; ++pos) {
@@ -66,8 +68,7 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
   }
 }
 
-template <typename U>
-static void PositionalEncodingQK(U* qk, const size_t layer_idx,
+static void PositionalEncodingQK(float* qk, const size_t layer_idx,
                                  const LayerWeightsPtrs& layer,
                                  const AttentionActivations& activations,
                                  const size_t pos, const float mul = 1.0f) {
@@ -97,7 +98,7 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos,
                                     const size_t last_pos,
                                     const hwy::Divisor& div_seq_len,
                                     const float* HWY_RESTRICT att,
-                                    const MatPtrT<float>& v,
+                                    const MatPtrT<BF16>& v,
                                     float* HWY_RESTRICT att_out) {
   const size_t qkv_dim = v.Cols();
   hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
@@ -110,7 +111,7 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos,
   } else {
     for (size_t pos = start_pos; pos <= last_pos; ++pos) {
       const size_t pos_modulo = div_seq_len.Remainder(pos);
-      const float* HWY_RESTRICT v_ptr = v.Row(pos_modulo);
+      const BF16* HWY_RESTRICT v_ptr = v.Row(pos_modulo);
       MulByConstAndAdd(att[pos_modulo], v_ptr, att_out, v.Cols());
     }
   }
@@ -118,11 +119,13 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos,
 
 // Calculates the attention outputs for a single q, which may be updated
 // in place for RMSNorm.
-void SingleDotSoftmaxWeightedSum(
-    const size_t pos, const size_t start_pos, const size_t last_pos,
-    float* HWY_RESTRICT q, const MatPtrT<float>& k, const MatPtrT<float>& v,
-    const size_t layer_idx, const LayerWeightsPtrs& layer,
-    const AttentionActivations& activations, float* HWY_RESTRICT att,
-    float* HWY_RESTRICT att_out) {
+void SingleDotSoftmaxWeightedSum(const size_t pos, const size_t start_pos,
+                                 const size_t last_pos, float* HWY_RESTRICT q,
+                                 const MatPtrT<BF16>& k, const MatPtrT<BF16>& v,
+                                 const size_t layer_idx,
+                                 const LayerWeightsPtrs& layer,
+                                 const AttentionActivations& activations,
+                                 float* HWY_RESTRICT att,
+                                 float* HWY_RESTRICT att_out) {
   const float att_cap = activations.config.att_cap;
   const float query_scale = activations.query_scale;
@ -136,6 +139,7 @@ void SingleDotSoftmaxWeightedSum(
layer.layer_config.qkv_dim); layer.layer_config.qkv_dim);
}); });
} }
PositionalEncodingQK(q, layer_idx, layer, activations, pos, query_scale); PositionalEncodingQK(q, layer_idx, layer, activations, pos, query_scale);
QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att); QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att);
@@ -220,10 +224,10 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
         // this query and head.
         const size_t kv_head_offset =
             layer_idx * cache_layer_size + head_offset;
-        MatPtrT<float> k("k_view", Extents2D(seq_len, qkv_dim));
+        MatPtrT<BF16> k("k_view", Extents2D(seq_len, qkv_dim));
         k.SetPtr(kv_cache.Row(0) + kv_head_offset,
                  kv_cache.Stride());
-        MatPtrT<float> v("v_view", Extents2D(seq_len, qkv_dim));
+        MatPtrT<BF16> v("v_view", Extents2D(seq_len, qkv_dim));
         v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim,
                  kv_cache.Stride());
@@ -263,7 +267,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
   // Set up MatMul row pointers for writing to KV, which consists of
   // `kv_heads` pairs of (k, v) vectors. This safely handles wraparound
   // because rows are computed modulo seq_len.
-  MatPtrT<float> kv_rows("kv", Extents2D(activations.pre_att_rms_out.Rows(),
+  MatPtrT<BF16> kv_rows("kv", Extents2D(activations.pre_att_rms_out.Rows(),
                                          layer.qkv_einsum_w2.Rows()));
   for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
        ++interleaved_idx) {
@@ -291,7 +295,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
         const size_t pos = qbatch.Pos(qi) + batch_idx;
         const size_t cache_pos = activations.div_seq_len.Remainder(pos);
         auto& kv_cache = qbatch.KV(qi).kv_cache;
-        float* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
+        BF16* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
                                  layer_idx * cache_layer_size +
                                  head * qkv_dim * 2;
@@ -302,7 +306,12 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
           });
         }
 
-        PositionalEncodingQK(kv, layer_idx, layer, activations, pos);
+        HWY_ALIGN float kv_f32[kMaxQKVDim];
+        const hn::ScalableTag<float> df;
+        DecompressAndZeroPad(df, MakeSpan(kv, qkv_dim), 0, kv_f32, qkv_dim);
+        PositionalEncodingQK(kv_f32, layer_idx, layer, activations, pos);
+        CompressPerThread tls;
+        Compress(kv_f32, qkv_dim, tls, MakeSpan(kv, qkv_dim), 0);
 
       });
 }

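The KV cache now holds BF16 while the RoPE math stays in f32: each freshly written row is decompressed into a kMaxQKVDim-sized stack buffer, rotated, then recompressed. As a rough illustration of what BF16 storage gives up in precision, here is a self-contained round-to-BF16 helper (hypothetical, for illustration only; the repo's actual conversion goes through CompressTraits):

  #include <cstdint>
  #include <cstring>

  // BF16 keeps an f32's sign bit, all 8 exponent bits, and the top 7 mantissa
  // bits, so the relative rounding error is bounded by 2^-8 (~0.4%).
  float RoundToBF16(float f) {  // hypothetical helper; ignores NaN handling
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    const uint32_t rounding = 0x7FFFu + ((bits >> 16) & 1u);  // nearest-even
    bits = (bits + rounding) & 0xFFFF0000u;                   // drop low 16 bits
    std::memcpy(&f, &bits, sizeof(f));
    return f;
  }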

@ -30,7 +30,7 @@ namespace gcpp {
namespace NAMESPACE { \ namespace NAMESPACE { \
void SingleDotSoftmaxWeightedSum( \ void SingleDotSoftmaxWeightedSum( \
const size_t pos, const size_t start_pos, const size_t last_pos, \ const size_t pos, const size_t start_pos, const size_t last_pos, \
float* HWY_RESTRICT q, const MatPtrT<float>& k, const MatPtrT<float>& v, \ float* HWY_RESTRICT q, const MatPtrT<BF16>& k, const MatPtrT<BF16>& v, \
size_t layer_idx, const LayerWeightsPtrs& layer, \ size_t layer_idx, const LayerWeightsPtrs& layer, \
const AttentionActivations& activations, float* HWY_RESTRICT att, \ const AttentionActivations& activations, float* HWY_RESTRICT att, \
float* HWY_RESTRICT att_out); \ float* HWY_RESTRICT att_out); \


@ -32,6 +32,7 @@
namespace gcpp { namespace gcpp {
static constexpr size_t kMaxConv1DWidth = 4; static constexpr size_t kMaxConv1DWidth = 4;
static constexpr size_t kMaxQKVDim = 1024;
// Instruction-tuned models require extra 'turn structure' tokens in prompts. // Instruction-tuned models require extra 'turn structure' tokens in prompts.
enum class PromptWrapping { enum class PromptWrapping {


@@ -19,7 +19,8 @@
 #include <stddef.h>
 
 #include "gemma/configs.h"  // ModelConfig
-#include "gemma/gemma_args.h"
+#include "gemma/gemma_args.h"  // InferenceArgs
+#include "util/basics.h"  // BF16
 #include "util/mat.h"
 
 namespace gcpp {

@@ -41,7 +42,7 @@ struct KVCache {
   MatStorageT<float> conv1d_cache;
   MatStorageT<float> rglru_cache;  // [griffin_layers, model_dim]
 
-  MatStorageT<float> kv_cache;  // [seq_len, layers * kv_heads * qkv_dim * 2]
+  MatStorageT<BF16> kv_cache;  // [seq_len, layers * kv_heads * qkv_dim * 2]
 
  private:
  // For use by other ctor and Copy()


@ -25,7 +25,7 @@
#include <string> #include <string>
#include "compression/types.h" #include "compression/types.h"
#include "gemma/configs.h" // ModelConfig #include "gemma/configs.h" // ModelConfig, kMaxQKVDim
#include "gemma/tensor_info.h" #include "gemma/tensor_info.h"
#include "gemma/tokenizer.h" #include "gemma/tokenizer.h"
#include "io/blob_store.h" #include "io/blob_store.h"
@ -234,6 +234,11 @@ static ModelConfig ReadOrDeduceConfig(BlobReader& reader,
HWY_ASSERT(config.model != Model::UNKNOWN); HWY_ASSERT(config.model != Model::UNKNOWN);
HWY_ASSERT(config.wrapping != PromptWrapping::kSentinel); HWY_ASSERT(config.wrapping != PromptWrapping::kSentinel);
HWY_ASSERT(config.weight != Type::kUnknown); HWY_ASSERT(config.weight != Type::kUnknown);
for (const LayerConfig& layer_config : config.layer_configs) {
if (static_cast<size_t>(layer_config.qkv_dim) > kMaxQKVDim) {
HWY_ABORT("Increase kMaxQKVDim to at least %u.", layer_config.qkv_dim);
}
}
// We trust the deserialized config, but checking helps to validate the // We trust the deserialized config, but checking helps to validate the
// deduction, which we rely on below for pre-2025 files. // deduction, which we rely on below for pre-2025 files.


@@ -341,6 +341,10 @@ HWY_NOINLINE void LayerNorm(const XT* x, const WT* HWY_RESTRICT scale,
     HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)];
     HWY_ALIGN float buf_scale[2 * hn::MaxLanes(df)];
     HWY_ALIGN float buf_bias[2 * hn::MaxLanes(df)];
+    // Ensure the second vectors are zeroed even if remaining <= NF.
+    hn::Store(hn::Zero(df), df, buf_x + NF);
+    hn::Store(hn::Zero(df), df, buf_scale + NF);
+    hn::Store(hn::Zero(df), df, buf_bias + NF);
     HWY_ALIGN OT buf_out[2 * hn::MaxLanes(df)];
     DecompressAndZeroPad(df, packed_x, i, buf_x, remaining);
     DecompressAndZeroPad(df, packed_scale, i, buf_scale, remaining);
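The zero-stores above are the "ensure second vector filled" fix from the commit message: DecompressAndZeroPad only guarantees zero padding up to the next vector boundary, so when remaining <= NF the lanes of the second vector are stale stack memory unless cleared first. The pattern, annotated (a sketch of the call sites in this diff):

  // remaining may be as small as 1; DecompressAndZeroPad then fills lanes
  // [0, NF) of buf_x (data plus zero padding) but never touches [NF, 2*NF).
  hn::Store(hn::Zero(df), df, buf_x + NF);  // clear the second vector
  DecompressAndZeroPad(df, packed_x, i, buf_x, remaining);
  const VF x0 = hn::Load(df, buf_x);        // data + zeros
  const VF x1 = hn::Load(df, buf_x + NF);   // all zeros: safe to accumulate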
@@ -399,94 +403,122 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
   of this rotation matrix which is simply the same matrix with -pos parameter)
 */
 
-// `inv_timescale[dim_qkv / 2]` is precomputed in Activations::Allocate.
-// This overload is called if kUseHalfRope.
+// `inv_timescale[dim_qkv / 2]` is precomputed in AttentionActivations.
+// This overload is called if `post_qk == PostQKType::HalfRope`.
 static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
     float* HWY_RESTRICT x, size_t dim_qkv,
-    const float* HWY_RESTRICT inv_timescale, int pos) {
+    const float* HWY_RESTRICT inv_timescale, const int pos) {
   PROFILER_ZONE("ops.Rope");
   HWY_DASSERT(dim_qkv % 2 == 0);
   const size_t half_dim_qkv = dim_qkv / 2;
-  for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
-    const float theta = StaticCast<float>(pos) * inv_timescale[dim];
-    const float cos_val = cosf(theta);
-    const float sin_val = sinf(theta);
-    const float x0 = x[dim];
-    const float x1 = x[dim + half_dim_qkv];
-    x[dim] = x0 * cos_val - x1 * sin_val;
-    x[dim + half_dim_qkv] = x0 * sin_val + x1 * cos_val;
-  }
-}
 
-// `inv_timescale[dim_qkv / 2]` is precomputed in Activations::Allocate.
-static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
-    const float mul, float* HWY_RESTRICT x, size_t dim_qkv,
-    const float* HWY_RESTRICT inv_timescale, int pos) {
-  PROFILER_ZONE("ops.RopeAndMulBy");
-  HWY_DASSERT(dim_qkv % 2 == 0);
-  const size_t half_dim_qkv = dim_qkv / 2;
-  using D = hn::ScalableTag<float>;
-  using V = hn::Vec<D>;
-  const D d;
+  const hn::ScalableTag<float> df;
+  const size_t NF = hn::Lanes(df);
+  using VF = hn::Vec<decltype(df)>;
+  const VF vpos = hn::Set(df, static_cast<float>(pos));
 
   // Vectorize computation for half_dim_qkv - (half_dim_qkv % Lanes)
-  const size_t vectorizable_dims = hwy::RoundDownTo(half_dim_qkv, hn::Lanes(d));
+  const size_t vectorizable_dims = hwy::RoundDownTo(half_dim_qkv, NF);
   size_t dim = 0;
-  for (; dim < vectorizable_dims; dim += hn::Lanes(d)) {
-    // Compute thetas
-    V pos_vec = hn::Set(d, pos);
-    V inv_time_scale_vec = hn::LoadU(d, inv_timescale + dim);
-    V theta_vec = hn::Mul(pos_vec, inv_time_scale_vec);
+  for (; dim < vectorizable_dims; dim += NF) {
+    const VF vinv_time_scale = hn::LoadU(df, inv_timescale + dim);
+    const VF vtheta = hn::Mul(vpos, vinv_time_scale);
 
     // Compute rotations.
-    V cos_theta_vec;
-    V sin_theta_vec;
-    hn::SinCos(d, theta_vec, sin_theta_vec, cos_theta_vec);
+    VF vcos_theta;
+    VF vsin_theta;
+    hn::SinCos(df, vtheta, vsin_theta, vcos_theta);
 
-    // Scale input with rotations and multiply with constant.
-    V mul_vec = hn::Set(d, mul);
-    V x0_vec = hn::Mul(mul_vec, hn::LoadU(d, x + dim));
-    V x1_vec = hn::Mul(mul_vec, hn::LoadU(d, x + dim + half_dim_qkv));
-
-    V xout_0_vec =
-        hn::MulSub(x0_vec, cos_theta_vec, hn::Mul(x1_vec, sin_theta_vec));
-    V xout_1_vec =
-        hn::MulAdd(x0_vec, sin_theta_vec, hn::Mul(x1_vec, cos_theta_vec));
-
-    // Store
-    hn::StoreU(xout_0_vec, d, x + dim);
-    hn::StoreU(xout_1_vec, d, x + dim + half_dim_qkv);
+    // Scale input with rotations.
+    const VF vx0 = hn::LoadU(df, x + dim);
+    const VF vx1 = hn::LoadU(df, x + dim + half_dim_qkv);
+    const VF vout0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
+    const VF vout1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
+    hn::StoreU(vout0, df, x + dim);
+    hn::StoreU(vout1, df, x + dim + half_dim_qkv);
   }
 
   // Vectorize computation for remaining dims - same as above, but with LoadN.
   const size_t remaining_dims = half_dim_qkv - dim;
-  HWY_DASSERT(remaining_dims < hn::Lanes(d));  // at most one iteration
+  HWY_DASSERT(remaining_dims < NF);  // at most one iteration
   if (remaining_dims != 0) {
-    // Compute thetas
-    V pos_vec = hn::Set(d, pos);
-    V inv_time_scale_vec = hn::LoadN(d, inv_timescale + dim, remaining_dims);
-    V theta_vec = hn::Mul(pos_vec, inv_time_scale_vec);
+    VF vinv_time_scale = hn::LoadN(df, inv_timescale + dim, remaining_dims);
+    VF vtheta = hn::Mul(vpos, vinv_time_scale);
 
     // Compute rotations.
-    V cos_theta_vec;
-    V sin_theta_vec;
-    hn::SinCos(d, theta_vec, sin_theta_vec, cos_theta_vec);
+    VF vcos_theta;
+    VF vsin_theta;
+    hn::SinCos(df, vtheta, vsin_theta, vcos_theta);
 
-    // Scale input with rotations and multiply with constant.
-    V mul_vec = hn::Set(d, mul);
-    V x0_vec = hn::Mul(mul_vec, hn::LoadN(d, x + dim, remaining_dims));
-    V x1_vec =
-        hn::Mul(mul_vec, hn::LoadN(d, x + dim + half_dim_qkv, remaining_dims));
-    V xout_0_vec =
-        hn::MulSub(x0_vec, cos_theta_vec, hn::Mul(x1_vec, sin_theta_vec));
-    V xout_1_vec =
-        hn::MulAdd(x0_vec, sin_theta_vec, hn::Mul(x1_vec, cos_theta_vec));
-
-    // Store
-    hn::StoreN(xout_0_vec, d, x + dim, remaining_dims);
-    hn::StoreN(xout_1_vec, d, x + dim + half_dim_qkv, remaining_dims);
+    // Scale input with rotations.
+    const VF vx0 = hn::LoadN(df, x + dim, remaining_dims);
+    const VF vx1 = hn::LoadN(df, x + dim + half_dim_qkv, remaining_dims);
+    const VF vout0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
+    const VF vout1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
+    hn::StoreN(vout0, df, x + dim, remaining_dims);
+    hn::StoreN(vout1, df, x + dim + half_dim_qkv, remaining_dims);
+  }
+}
+
+// `inv_timescale[dim_qkv / 2]` is precomputed in AttentionActivations.
+static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
+    const float mul, float* HWY_RESTRICT x, size_t dim_qkv,
+    const float* HWY_RESTRICT inv_timescale, const int pos) {
+  PROFILER_ZONE("ops.RopeAndMulBy");
+  HWY_DASSERT(dim_qkv % 2 == 0);
+  const size_t half_dim_qkv = dim_qkv / 2;
+
+  const hn::ScalableTag<float> df;
+  const size_t NF = hn::Lanes(df);
+  using VF = hn::Vec<decltype(df)>;
+  const VF vmul = hn::Set(df, mul);
+  const VF vpos = hn::Set(df, static_cast<float>(pos));
+
+  // Vectorize computation for half_dim_qkv - (half_dim_qkv % Lanes)
+  const size_t vectorizable_dims = hwy::RoundDownTo(half_dim_qkv, NF);
+  size_t dim = 0;
+  for (; dim < vectorizable_dims; dim += NF) {
+    const VF vinv_time_scale = hn::LoadU(df, inv_timescale + dim);
+    const VF vtheta = hn::Mul(vpos, vinv_time_scale);
+
+    // Compute rotations.
+    VF vcos_theta;
+    VF vsin_theta;
+    hn::SinCos(df, vtheta, vsin_theta, vcos_theta);
+
+    // Scale input with rotations and multiply with constant.
+    const VF vx0 = hn::Mul(vmul, hn::LoadU(df, x + dim));
+    const VF vx1 = hn::Mul(vmul, hn::LoadU(df, x + dim + half_dim_qkv));
+    const VF vout0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
+    const VF vout1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
+    hn::StoreU(vout0, df, x + dim);
+    hn::StoreU(vout1, df, x + dim + half_dim_qkv);
+  }
+
+  // Vectorize computation for remaining dims - same as above, but with LoadN.
+  const size_t remaining_dims = half_dim_qkv - dim;
+  HWY_DASSERT(remaining_dims < NF);  // at most one iteration
+  if (remaining_dims != 0) {
+    VF vinv_time_scale = hn::LoadN(df, inv_timescale + dim, remaining_dims);
+    VF vtheta = hn::Mul(vpos, vinv_time_scale);
+
+    // Compute rotations.
+    VF vcos_theta;
+    VF vsin_theta;
+    hn::SinCos(df, vtheta, vsin_theta, vcos_theta);
+
+    // Scale input with rotations and multiply with constant.
+    const VF vx0 = hn::Mul(vmul, hn::LoadN(df, x + dim, remaining_dims));
+    const VF vx1 =
+        hn::Mul(vmul, hn::LoadN(df, x + dim + half_dim_qkv, remaining_dims));
+    const VF vout0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
+    const VF vout1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
+    hn::StoreN(vout0, df, x + dim, remaining_dims);
+    hn::StoreN(vout1, df, x + dim + half_dim_qkv, remaining_dims);
   }
 }
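For reference, the pair rotation the vector code must reproduce, mirroring the deleted scalar loop (this helper is illustrative only):

  #include <cmath>

  // Rotate the pair (x0, x1) by theta, as in the removed scalar Rope().
  void RopePairScalar(float& x0, float& x1, float theta) {
    const float c = cosf(theta);
    const float s = sinf(theta);
    const float out0 = x0 * c - x1 * s;  // both outputs read the ORIGINAL x0
    const float out1 = x0 * s + x1 * c;
    x0 = out0;
    x1 = out1;
  }

Because both outputs read the original x0, the vector path writes vout0/vout1 into fresh registers rather than overwriting vx0 before vx1 is computed.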
@@ -521,6 +553,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(const XT* HWY_RESTRICT x,
   HWY_DASSERT(remaining1 < NF);
   if (HWY_UNLIKELY(remaining != 0)) {
     HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)];
+    // Ensure the second vector is zeroed even if remaining <= NF.
+    hn::Store(hn::Zero(df), df, buf_x + NF);
     DecompressAndZeroPad(df, packed_x, i, buf_x, remaining);
     const VF x0 = hn::Load(df, buf_x);
     const VF x1 = hn::Load(df, buf_x + NF);
@@ -586,42 +620,43 @@ static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x,
   }
 }
 
-static HWY_NOINLINE void MulBy(const float* HWY_RESTRICT other,
-                               float* HWY_RESTRICT x, const size_t size,
-                               const size_t max_pos) {
-  PROFILER_ZONE("ops.MulBy");
-  HWY_DASSERT(max_pos <= size);
-  namespace hn = hwy::HWY_NAMESPACE;
-  using D = hn::ScalableTag<float>;
-  using V = hn::Vec<D>;
-
-  hn::Transform1(D(), x, max_pos, other,
-                 [](const auto d, const V x, const V other)
-                     HWY_ATTR { return hn::Mul(x, other); });
-}
-
-static HWY_INLINE HWY_MAYBE_UNUSED void MulBy(const float* HWY_RESTRICT other,
-                                              float* HWY_RESTRICT x,
-                                              const size_t size) {
-  return MulBy(other, x, size, size);
-}
-
-static HWY_NOINLINE void MulByConst(const float c, float* HWY_RESTRICT x,
-                                    const size_t size, const size_t max_pos) {
+template <typename XT>
+HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(float c, XT* HWY_RESTRICT x,
+                                              size_t size) {
   PROFILER_ZONE("ops.MulByConst");
-  HWY_DASSERT(max_pos <= size);
   namespace hn = hwy::HWY_NAMESPACE;
-  using D = hn::ScalableTag<float>;
-  using V = hn::Vec<D>;
-  hn::Transform(D(), x, max_pos, [c](const auto d, const V x) HWY_ATTR {
-    return hn::Mul(x, hn::Set(d, c));
-  });
-}
+  const hn::ScalableTag<float> df;
+  const size_t NF = hn::Lanes(df);
+  using VF = hn::Vec<decltype(df)>;
 
-static HWY_INLINE HWY_MAYBE_UNUSED void MulByConst(const float c,
-                                                   float* HWY_RESTRICT x,
-                                                   const size_t size) {
-  MulByConst(c, x, size, size);
+  const VF v_c = hn::Set(df, c);
+  const auto packed_x = MakeSpan(x, size);
+
+  size_t i = 0;
+  if (size >= 2 * NF) {
+    for (; i <= size - 2 * NF; i += 2 * NF) {
+      VF x0, x1;
+      Decompress2(df, packed_x, i, x0, x1);
+      x0 = hn::Mul(x0, v_c);
+      x1 = hn::Mul(x1, v_c);
+      Compress2(df, x0, x1, packed_x, i);
+    }
+  }
+
+  const size_t remaining = size - i;
+  HWY_DASSERT(remaining < 2 * NF);
+  if (HWY_UNLIKELY(remaining != 0)) {
+    HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)];
+    // Ensure the second vector is zeroed even if remaining <= NF.
+    hn::Store(hn::Zero(df), df, buf_x + NF);
+    DecompressAndZeroPad(df, packed_x, i, buf_x, remaining);
+    VF x0 = hn::Load(df, buf_x);
+    VF x1 = hn::Load(df, buf_x + NF);
+    x0 = hn::Mul(x0, v_c);
+    x1 = hn::Mul(x1, v_c);
+    HWY_ALIGN XT buf_out[2 * hn::MaxLanes(df)];
+    Compress2(df, x0, x1, MakeSpan(buf_out, 2 * NF), 0);
+    hwy::CopyBytes(buf_out, x + i, remaining * sizeof(XT));
+  }
 }
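Since MulByConst is now templated on the element type and round-trips through Decompress2/Compress2, it scales packed BF16 buffers in place as well as f32. A hypothetical call:

  BF16 att_row[64];                 // hypothetical BF16 attention row
  MulByConst(0.125f, att_row, 64);  // decompress, scale, recompress in place
  float logits[256];                // f32 works exactly as before
  MulByConst(0.5f, logits, 256);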
template <typename XT, typename OT> template <typename XT, typename OT>
@@ -656,6 +691,9 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(float c,
   if (HWY_UNLIKELY(remaining != 0)) {
     HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)];
     HWY_ALIGN float buf_out[2 * hn::MaxLanes(df)];
+    // Ensure the second vectors are zeroed even if remaining <= NF.
+    hn::Store(hn::Zero(df), df, buf_x + NF);
+    hn::Store(hn::Zero(df), df, buf_out + NF);
     DecompressAndZeroPad(df, packed_x, i, buf_x, remaining);
     DecompressAndZeroPad(df, packed_out, i, buf_out, remaining);
     const VF x0 = hn::Load(df, buf_x);
@@ -671,11 +709,9 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(float c,
 
 // See below for a specialized version for top-1 sampling.
 static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
-                                 const size_t mask_pos,
                                  float temperature = 1.0f) {
   PROFILER_ZONE("ops.Softmax");
   HWY_DASSERT(size != 0);
-  HWY_DASSERT(mask_pos <= size);
 
   namespace hn = hwy::HWY_NAMESPACE;
   using D = hn::ScalableTag<float>;
@@ -685,13 +721,13 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
   const V vmin = hn::Set(d, hwy::LowestValue<float>());
   V vmax = vmin;
   V* pmax = &vmax;  // workaround for SVE: cannot capture &vector directly
-  hn::Foreach(d, x, mask_pos, vmin,
-              [pmax](const auto d, const V value)
-                  HWY_ATTR { *pmax = hn::Max(*pmax, value); });
+  hn::Foreach(d, x, size, vmin, [pmax](const auto d, const V value) HWY_ATTR {
+    *pmax = hn::Max(*pmax, value);
+  });
   vmax = hn::MaxOfLanes(d, vmax);
 
   // Subtract max (avoid precision loss for large exponents) and exponentiate.
-  hn::Transform(d, x, mask_pos, [pmax](const auto d, const V value) HWY_ATTR {
+  hn::Transform(d, x, size, [pmax](const auto d, const V value) HWY_ATTR {
     if constexpr (HWY_TARGET & HWY_ALL_SVE) {
       // Temporary workaround for buggy SVE codegen: avoid inlined Exp().
       return hn::CallExp(d, hn::Sub(value, *pmax));
@@ -702,7 +738,7 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
 
   if (temperature != 1.0f) {
     const float temperature_inv = 1.0f / temperature;
-    hn::Transform(d, x, mask_pos,
+    hn::Transform(d, x, size,
                   [temperature_inv](const auto d, const V value) HWY_ATTR {
                     return hn::Mul(value, hn::Set(d, temperature_inv));
                   });
@@ -712,16 +748,10 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
   // not make a huge difference. It halves the standard deviation of the sum of
   // the normalized probabilities from 1E-7 to 5E-8, but actually also changes
   // the generated text after a few hundred tokens.
-  const float sum_exp = Sum(d, x, mask_pos);
+  const float sum_exp = Sum(d, x, size);
   // Double-precision reciprocal does not appear to affect the results.
   const float mul = 1.0f / sum_exp;
-  MulByConst(mul, x, size, mask_pos);
-}
-
-static HWY_INLINE HWY_MAYBE_UNUSED void Softmax(float* HWY_RESTRICT x,
-                                                const size_t size,
-                                                float temperature = 1.0f) {
-  Softmax(x, size, size, temperature);
+  MulByConst(mul, x, size);
 }
 
 // Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max /
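With the mask_pos parameter and the forwarding overload gone, callers normalize over exactly the extent they pass. A sketch of the call-site change (assuming att holds last_pos + 1 valid logits):

  Softmax(att, last_pos + 1);        // was: Softmax(att, size, last_pos + 1)
  Softmax(att, last_pos + 1, 0.7f);  // temperature still defaults to 1.0f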