1.02x batch decode speedup: BF16 KV cache

ops-inl.h: Vectorize Rope(), template
Remove unused MulBy, and extra-arg overloads of MulByConst and Softmax
Fix for DecompressAndZeroPad: ensure second vector filled

PiperOrigin-RevId: 772779163
This commit is contained in:
Jan Wassenberg 2025-06-17 23:21:24 -07:00 committed by Copybara-Service
parent 606e22155a
commit 343482c7ef
8 changed files with 214 additions and 155 deletions

View File

@ -449,6 +449,7 @@ cc_library(
srcs = ["gemma/kv_cache.cc"],
hdrs = ["gemma/kv_cache.h"],
deps = [
":basics",
":configs",
":gemma_args",
":mat",
@ -504,6 +505,7 @@ cc_library(
":threading",
":threading_context",
":weights",
"//compression:compress",
"//compression:types",
"//io:blob_store",
"//io",

View File

@ -517,8 +517,19 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
}
}
// Stores two f32 vectors to f32 or bf16; avoids duplicating RMSNorm and
// RMSNormInplace for the two output types.
// Same as above, but without parallelization nor benchmarking.
template <typename Packed>
HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
CompressPerThread& tls,
const PackedSpan<Packed>& packed,
const size_t packed_ofs) {
packed.BoundsCheck(packed_ofs, num);
using Traits = CompressTraits<hwy::RemoveConst<Packed>>;
const hn::ScalableTag<float> df;
Traits::Compress(df, raw, num, tls, packed, packed_ofs);
}
// Stores two f32 vectors to f32 or bf16.
template <class DF, typename Packed, HWY_IF_F32_D(DF), class VF = hn::Vec<DF>>
void Compress2(DF df, VF raw0, VF raw1, const PackedSpan<Packed>& packed,
const size_t packed_ofs) {

View File

@ -24,6 +24,7 @@
#endif // HWY_DISABLED_TARGETS
#include "gemma/activations.h"
#include "gemma/configs.h" // kMaxQKVDim
#include "gemma/gemma.h"
#include "gemma/weights.h"
#include "util/threading.h"
@ -39,6 +40,7 @@
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "compression/compress-inl.h"
#include "ops/ops-inl.h"
HWY_BEFORE_NAMESPACE();
@ -50,7 +52,7 @@ namespace HWY_NAMESPACE {
static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
const hwy::Divisor& div_seq_len,
const float* HWY_RESTRICT q,
const MatPtrT<float>& k, float* HWY_RESTRICT att) {
const MatPtrT<BF16>& k, float* HWY_RESTRICT att) {
if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
// Slightly faster: no wraparound.
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
@ -66,8 +68,7 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
}
}
template <typename U>
static void PositionalEncodingQK(U* qk, const size_t layer_idx,
static void PositionalEncodingQK(float* qk, const size_t layer_idx,
const LayerWeightsPtrs& layer,
const AttentionActivations& activations,
const size_t pos, const float mul = 1.0f) {
@ -97,7 +98,7 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos,
const size_t last_pos,
const hwy::Divisor& div_seq_len,
const float* HWY_RESTRICT att,
const MatPtrT<float>& v,
const MatPtrT<BF16>& v,
float* HWY_RESTRICT att_out) {
const size_t qkv_dim = v.Cols();
hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
@ -110,7 +111,7 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos,
} else {
for (size_t pos = start_pos; pos <= last_pos; ++pos) {
const size_t pos_modulo = div_seq_len.Remainder(pos);
const float* HWY_RESTRICT v_ptr = v.Row(pos_modulo);
const BF16* HWY_RESTRICT v_ptr = v.Row(pos_modulo);
MulByConstAndAdd(att[pos_modulo], v_ptr, att_out, v.Cols());
}
}
@ -118,12 +119,14 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos,
// Calculates the attention outputs for a single q, which may be updated
// in place for RMSNorm.
void SingleDotSoftmaxWeightedSum(
const size_t pos, const size_t start_pos, const size_t last_pos,
float* HWY_RESTRICT q, const MatPtrT<float>& k, const MatPtrT<float>& v,
const size_t layer_idx, const LayerWeightsPtrs& layer,
const AttentionActivations& activations, float* HWY_RESTRICT att,
float* HWY_RESTRICT att_out) {
void SingleDotSoftmaxWeightedSum(const size_t pos, const size_t start_pos,
const size_t last_pos, float* HWY_RESTRICT q,
const MatPtrT<BF16>& k, const MatPtrT<BF16>& v,
const size_t layer_idx,
const LayerWeightsPtrs& layer,
const AttentionActivations& activations,
float* HWY_RESTRICT att,
float* HWY_RESTRICT att_out) {
const float att_cap = activations.config.att_cap;
const float query_scale = activations.query_scale;
const size_t seq_len =
@ -136,6 +139,7 @@ void SingleDotSoftmaxWeightedSum(
layer.layer_config.qkv_dim);
});
}
PositionalEncodingQK(q, layer_idx, layer, activations, pos, query_scale);
QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att);
@ -220,10 +224,10 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
// this query and head.
const size_t kv_head_offset =
layer_idx * cache_layer_size + head_offset;
MatPtrT<float> k("k_view", Extents2D(seq_len, qkv_dim));
MatPtrT<BF16> k("k_view", Extents2D(seq_len, qkv_dim));
k.SetPtr(kv_cache.Row(0) + kv_head_offset,
kv_cache.Stride());
MatPtrT<float> v("v_view", Extents2D(seq_len, qkv_dim));
MatPtrT<BF16> v("v_view", Extents2D(seq_len, qkv_dim));
v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim,
kv_cache.Stride());
@ -263,8 +267,8 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
// Set up MatMul row pointers for writing to KV, which consists of
// `kv_heads` pairs of (k, v) vectors. This safely handles wraparound
// because rows are computed modulo seq_len.
MatPtrT<float> kv_rows("kv", Extents2D(activations.pre_att_rms_out.Rows(),
layer.qkv_einsum_w2.Rows()));
MatPtrT<BF16> kv_rows("kv", Extents2D(activations.pre_att_rms_out.Rows(),
layer.qkv_einsum_w2.Rows()));
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
++interleaved_idx) {
const size_t qi = div_qbatch.Remainder(interleaved_idx);
@ -291,9 +295,9 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
const size_t pos = qbatch.Pos(qi) + batch_idx;
const size_t cache_pos = activations.div_seq_len.Remainder(pos);
auto& kv_cache = qbatch.KV(qi).kv_cache;
float* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
layer_idx * cache_layer_size +
head * qkv_dim * 2;
BF16* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
layer_idx * cache_layer_size +
head * qkv_dim * 2;
// Apply further processing to K.
if (layer.key_norm_scale.HasPtr()) {
@ -302,7 +306,12 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
});
}
PositionalEncodingQK(kv, layer_idx, layer, activations, pos);
HWY_ALIGN float kv_f32[kMaxQKVDim];
const hn::ScalableTag<float> df;
DecompressAndZeroPad(df, MakeSpan(kv, qkv_dim), 0, kv_f32, qkv_dim);
PositionalEncodingQK(kv_f32, layer_idx, layer, activations, pos);
CompressPerThread tls;
Compress(kv_f32, qkv_dim, tls, MakeSpan(kv, qkv_dim), 0);
});
}

View File

@ -26,25 +26,25 @@
namespace gcpp {
// Passed to HWY_VISIT_TARGETS; declares for one target.
#define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \
namespace NAMESPACE { \
void SingleDotSoftmaxWeightedSum( \
const size_t pos, const size_t start_pos, const size_t last_pos, \
float* HWY_RESTRICT q, const MatPtrT<float>& k, const MatPtrT<float>& v, \
size_t layer_idx, const LayerWeightsPtrs& layer, \
const AttentionActivations& activations, float* HWY_RESTRICT att, \
float* HWY_RESTRICT att_out); \
\
void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \
const LayerWeightsPtrs& layer, \
AttentionActivations& activations, \
QBatch& qbatch, NestedPools& pools); \
\
void GemmaAttention(size_t num_tokens, const size_t layer_idx, \
const LayerWeightsPtrs& layer, \
AttentionActivations& activations, QBatch& qbatch, \
MatMulEnv& env, int flags); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
#define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \
namespace NAMESPACE { \
void SingleDotSoftmaxWeightedSum( \
const size_t pos, const size_t start_pos, const size_t last_pos, \
float* HWY_RESTRICT q, const MatPtrT<BF16>& k, const MatPtrT<BF16>& v, \
size_t layer_idx, const LayerWeightsPtrs& layer, \
const AttentionActivations& activations, float* HWY_RESTRICT att, \
float* HWY_RESTRICT att_out); \
\
void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \
const LayerWeightsPtrs& layer, \
AttentionActivations& activations, \
QBatch& qbatch, NestedPools& pools); \
\
void GemmaAttention(size_t num_tokens, const size_t layer_idx, \
const LayerWeightsPtrs& layer, \
AttentionActivations& activations, QBatch& qbatch, \
MatMulEnv& env, int flags); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE
// Function declarations for each SIMD target. Allows direct call from the

View File

@ -32,6 +32,7 @@
namespace gcpp {
static constexpr size_t kMaxConv1DWidth = 4;
static constexpr size_t kMaxQKVDim = 1024;
// Instruction-tuned models require extra 'turn structure' tokens in prompts.
enum class PromptWrapping {

View File

@ -19,7 +19,8 @@
#include <stddef.h>
#include "gemma/configs.h" // ModelConfig
#include "gemma/gemma_args.h"
#include "gemma/gemma_args.h" // InferenceArgs
#include "util/basics.h" // BF16
#include "util/mat.h"
namespace gcpp {
@ -41,7 +42,7 @@ struct KVCache {
MatStorageT<float> conv1d_cache;
MatStorageT<float> rglru_cache; // [griffin_layers, model_dim]
MatStorageT<float> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2]
MatStorageT<BF16> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2]
private:
// For use by other ctor and Copy()

View File

@ -25,7 +25,7 @@
#include <string>
#include "compression/types.h"
#include "gemma/configs.h" // ModelConfig
#include "gemma/configs.h" // ModelConfig, kMaxQKVDim
#include "gemma/tensor_info.h"
#include "gemma/tokenizer.h"
#include "io/blob_store.h"
@ -234,6 +234,11 @@ static ModelConfig ReadOrDeduceConfig(BlobReader& reader,
HWY_ASSERT(config.model != Model::UNKNOWN);
HWY_ASSERT(config.wrapping != PromptWrapping::kSentinel);
HWY_ASSERT(config.weight != Type::kUnknown);
for (const LayerConfig& layer_config : config.layer_configs) {
if (static_cast<size_t>(layer_config.qkv_dim) > kMaxQKVDim) {
HWY_ABORT("Increase kMaxQKVDim to at least %u.", layer_config.qkv_dim);
}
}
// We trust the deserialized config, but checking helps to validate the
// deduction, which we rely on below for pre-2025 files.

View File

@ -341,6 +341,10 @@ HWY_NOINLINE void LayerNorm(const XT* x, const WT* HWY_RESTRICT scale,
HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)];
HWY_ALIGN float buf_scale[2 * hn::MaxLanes(df)];
HWY_ALIGN float buf_bias[2 * hn::MaxLanes(df)];
// Ensure the second vectors are zeroed even if remaining <= NF.
hn::Store(hn::Zero(df), df, buf_x + NF);
hn::Store(hn::Zero(df), df, buf_scale + NF);
hn::Store(hn::Zero(df), df, buf_bias + NF);
HWY_ALIGN OT buf_out[2 * hn::MaxLanes(df)];
DecompressAndZeroPad(df, packed_x, i, buf_x, remaining);
DecompressAndZeroPad(df, packed_scale, i, buf_scale, remaining);
@ -399,94 +403,122 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
of this rotation matrix which is simply the same matrix with -pos parameter)
*/
// `inv_timescale[dim_qkv / 2]` is precomputed in Activations::Allocate.
// This overload is called if kUseHalfRope.
// `inv_timescale[dim_qkv / 2]` is precomputed in AttentionActivations.
// This overload is called if `post_qk == PostQKType::HalfRope`.
static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
float* HWY_RESTRICT x, size_t dim_qkv,
const float* HWY_RESTRICT inv_timescale, int pos) {
const float* HWY_RESTRICT inv_timescale, const int pos) {
PROFILER_ZONE("ops.Rope");
HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2;
for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
const float theta = StaticCast<float>(pos) * inv_timescale[dim];
const float cos_val = cosf(theta);
const float sin_val = sinf(theta);
const float x0 = x[dim];
const float x1 = x[dim + half_dim_qkv];
x[dim] = x0 * cos_val - x1 * sin_val;
x[dim + half_dim_qkv] = x0 * sin_val + x1 * cos_val;
}
}
// `inv_timescale[dim_qkv / 2]` is precomputed in Activations::Allocate.
static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
const float mul, float* HWY_RESTRICT x, size_t dim_qkv,
const float* HWY_RESTRICT inv_timescale, int pos) {
PROFILER_ZONE("ops.RopeAndMulBy");
HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2;
using D = hn::ScalableTag<float>;
using V = hn::Vec<D>;
const D d;
const hn::ScalableTag<float> df;
const size_t NF = hn::Lanes(df);
using VF = hn::Vec<decltype(df)>;
const VF vpos = hn::Set(df, static_cast<float>(pos));
// Vectorize computation for half_dim_qkv - (half_dim_qkv % Lanes)
const size_t vectorizable_dims = hwy::RoundDownTo(half_dim_qkv, hn::Lanes(d));
const size_t vectorizable_dims = hwy::RoundDownTo(half_dim_qkv, NF);
size_t dim = 0;
for (; dim < vectorizable_dims; dim += hn::Lanes(d)) {
// Compute thetas
V pos_vec = hn::Set(d, pos);
V inv_time_scale_vec = hn::LoadU(d, inv_timescale + dim);
V theta_vec = hn::Mul(pos_vec, inv_time_scale_vec);
for (; dim < vectorizable_dims; dim += NF) {
const VF vinv_time_scale = hn::LoadU(df, inv_timescale + dim);
const VF vtheta = hn::Mul(vpos, vinv_time_scale);
// Compute rotations.
V cos_theta_vec;
V sin_theta_vec;
hn::SinCos(d, theta_vec, sin_theta_vec, cos_theta_vec);
VF vcos_theta;
VF vsin_theta;
hn::SinCos(df, vtheta, vsin_theta, vcos_theta);
// Scale input with rotations and multiply with constant.
V mul_vec = hn::Set(d, mul);
V x0_vec = hn::Mul(mul_vec, hn::LoadU(d, x + dim));
V x1_vec = hn::Mul(mul_vec, hn::LoadU(d, x + dim + half_dim_qkv));
// Scale input with rotations.
VF vx0 = hn::LoadU(df, x + dim);
VF vx1 = hn::LoadU(df, x + dim + half_dim_qkv);
vx0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
vx1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
V xout_0_vec = hn::MulSub(x0_vec, cos_theta_vec,
hn::Mul(x1_vec, sin_theta_vec));
V xout_1_vec = hn::MulAdd(x0_vec, sin_theta_vec,
hn::Mul(x1_vec, cos_theta_vec));
// Store
hn::StoreU(xout_0_vec, d, x + dim);
hn::StoreU(xout_1_vec, d, x + dim + half_dim_qkv);
hn::StoreU(vx0, df, x + dim);
hn::StoreU(vx1, df, x + dim + half_dim_qkv);
}
// Vectorize computation for remaining dims - same as above, but with LoadN.
const size_t remaining_dims = half_dim_qkv - dim;
HWY_DASSERT(remaining_dims < hn::Lanes(d)); // at most one iteration
HWY_DASSERT(remaining_dims < NF); // at most one iteration
if (remaining_dims != 0) {
// Compute thetas
V pos_vec = hn::Set(d, pos);
V inv_time_scale_vec = hn::LoadN(d, inv_timescale + dim, remaining_dims);
V theta_vec = hn::Mul(pos_vec, inv_time_scale_vec);
VF vinv_time_scale = hn::LoadN(df, inv_timescale + dim, remaining_dims);
VF vtheta = hn::Mul(vpos, vinv_time_scale);
// Compute rotations.
V cos_theta_vec;
V sin_theta_vec;
hn::SinCos(d, theta_vec, sin_theta_vec, cos_theta_vec);
VF vcos_theta;
VF vsin_theta;
hn::SinCos(df, vtheta, vsin_theta, vcos_theta);
// Scale input with rotations.
VF vx0 = hn::LoadN(df, x + dim, remaining_dims);
VF vx1 = hn::LoadN(df, x + dim + half_dim_qkv, remaining_dims);
vx0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
vx1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
hn::StoreN(vx0, df, x + dim, remaining_dims);
hn::StoreN(vx1, df, x + dim + half_dim_qkv, remaining_dims);
}
}
// `inv_timescale[dim_qkv / 2]` is precomputed in AttentionActivations.
static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
const float mul, float* HWY_RESTRICT x, size_t dim_qkv,
const float* HWY_RESTRICT inv_timescale, const int pos) {
PROFILER_ZONE("ops.RopeAndMulBy");
HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2;
const hn::ScalableTag<float> df;
const size_t NF = hn::Lanes(df);
using VF = hn::Vec<decltype(df)>;
const VF vmul = hn::Set(df, mul);
const VF vpos = hn::Set(df, static_cast<float>(pos));
// Vectorize computation for half_dim_qkv - (half_dim_qkv % Lanes)
const size_t vectorizable_dims = hwy::RoundDownTo(half_dim_qkv, NF);
size_t dim = 0;
for (; dim < vectorizable_dims; dim += NF) {
const VF vinv_time_scale = hn::LoadU(df, inv_timescale + dim);
const VF vtheta = hn::Mul(vpos, vinv_time_scale);
// Compute rotations.
VF vcos_theta;
VF vsin_theta;
hn::SinCos(df, vtheta, vsin_theta, vcos_theta);
// Scale input with rotations and multiply with constant.
V mul_vec = hn::Set(d, mul);
V x0_vec = hn::Mul(mul_vec, hn::LoadN(d, x + dim, remaining_dims));
V x1_vec =
hn::Mul(mul_vec, hn::LoadN(d, x + dim + half_dim_qkv, remaining_dims));
VF vx0 = hn::Mul(vmul, hn::LoadU(df, x + dim));
VF vx1 = hn::Mul(vmul, hn::LoadU(df, x + dim + half_dim_qkv));
vx0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
vx1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
V xout_0_vec =
hn::MulSub(x0_vec, cos_theta_vec, hn::Mul(x1_vec, sin_theta_vec));
V xout_1_vec =
hn::MulAdd(x0_vec, sin_theta_vec, hn::Mul(x1_vec, cos_theta_vec));
hn::StoreU(vx0, df, x + dim);
hn::StoreU(vx1, df, x + dim + half_dim_qkv);
}
// Store
hn::StoreN(xout_0_vec, d, x + dim, remaining_dims);
hn::StoreN(xout_1_vec, d, x + dim + half_dim_qkv, remaining_dims);
// Vectorize computation for remaining dims - same as above, but with LoadN.
const size_t remaining_dims = half_dim_qkv - dim;
HWY_DASSERT(remaining_dims < NF); // at most one iteration
if (remaining_dims != 0) {
VF vinv_time_scale = hn::LoadN(df, inv_timescale + dim, remaining_dims);
VF vtheta = hn::Mul(vpos, vinv_time_scale);
// Compute rotations.
VF vcos_theta;
VF vsin_theta;
hn::SinCos(df, vtheta, vsin_theta, vcos_theta);
// Scale input with rotations and multiply with constant.
VF vx0 = hn::Mul(vmul, hn::LoadN(df, x + dim, remaining_dims));
VF vx1 =
hn::Mul(vmul, hn::LoadN(df, x + dim + half_dim_qkv, remaining_dims));
vx0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
vx1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
hn::StoreN(vx0, df, x + dim, remaining_dims);
hn::StoreN(vx1, df, x + dim + half_dim_qkv, remaining_dims);
}
}
@ -521,6 +553,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(const XT* HWY_RESTRICT x,
HWY_DASSERT(remaining1 < NF);
if (HWY_UNLIKELY(remaining != 0)) {
HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)];
// Ensure the second vector is zeroed even if remaining <= NF.
hn::Store(hn::Zero(df), df, buf_x + NF);
DecompressAndZeroPad(df, packed_x, i, buf_x, remaining);
const VF x0 = hn::Load(df, buf_x);
const VF x1 = hn::Load(df, buf_x + NF);
@ -586,42 +620,43 @@ static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x,
}
}
static HWY_NOINLINE void MulBy(const float* HWY_RESTRICT other,
float* HWY_RESTRICT x, const size_t size,
const size_t max_pos) {
PROFILER_ZONE("ops.MulBy");
HWY_DASSERT(max_pos <= size);
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
using V = hn::Vec<D>;
hn::Transform1(D(), x, max_pos, other,
[](const auto d, const V x, const V other)
HWY_ATTR { return hn::Mul(x, other); });
}
static HWY_INLINE HWY_MAYBE_UNUSED void MulBy(const float* HWY_RESTRICT other,
float* HWY_RESTRICT x,
const size_t size) {
return MulBy(other, x, size, size);
}
static HWY_NOINLINE void MulByConst(const float c, float* HWY_RESTRICT x,
const size_t size, const size_t max_pos) {
template <typename XT>
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(float c, XT* HWY_RESTRICT x,
size_t size) {
PROFILER_ZONE("ops.MulByConst");
HWY_DASSERT(max_pos <= size);
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
using V = hn::Vec<D>;
hn::Transform(D(), x, max_pos, [c](const auto d, const V x) HWY_ATTR {
return hn::Mul(x, hn::Set(d, c));
});
}
const hn::ScalableTag<float> df;
const size_t NF = hn::Lanes(df);
using VF = hn::Vec<decltype(df)>;
static HWY_INLINE HWY_MAYBE_UNUSED void MulByConst(const float c,
float* HWY_RESTRICT x,
const size_t size) {
MulByConst(c, x, size, size);
const VF v_c = hn::Set(df, c);
const auto packed_x = MakeSpan(x, size);
size_t i = 0;
if (size >= 2 * NF) {
for (; i <= size - 2 * NF; i += 2 * NF) {
VF x0, x1;
Decompress2(df, packed_x, i, x0, x1);
x0 = hn::Mul(x0, v_c);
x1 = hn::Mul(x1, v_c);
Compress2(df, x0, x1, packed_x, i);
}
}
const size_t remaining = size - i;
HWY_DASSERT(remaining < 2 * NF);
if (HWY_UNLIKELY(remaining != 0)) {
HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)];
// Ensure the second vector is zeroed even if remaining <= NF.
hn::Store(hn::Zero(df), df, buf_x + NF);
DecompressAndZeroPad(df, packed_x, i, buf_x, remaining);
VF x0 = hn::Load(df, buf_x);
VF x1 = hn::Load(df, buf_x + NF);
x0 = hn::Mul(x0, v_c);
x1 = hn::Mul(x1, v_c);
Compress2(df, x0, x1, MakeSpan(buf_x, 2 * NF), 0);
hwy::CopyBytes(buf_x, x + i, remaining * sizeof(XT));
}
}
template <typename XT, typename OT>
@ -656,6 +691,9 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(float c,
if (HWY_UNLIKELY(remaining != 0)) {
HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)];
HWY_ALIGN float buf_out[2 * hn::MaxLanes(df)];
// Ensure the second vectors are zeroed even if remaining <= NF.
hn::Store(hn::Zero(df), df, buf_x + NF);
hn::Store(hn::Zero(df), df, buf_out + NF);
DecompressAndZeroPad(df, packed_x, i, buf_x, remaining);
DecompressAndZeroPad(df, packed_out, i, buf_out, remaining);
const VF x0 = hn::Load(df, buf_x);
@ -671,11 +709,9 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(float c,
// See below for a specialized version for top-1 sampling.
static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
const size_t mask_pos,
float temperature = 1.0f) {
PROFILER_ZONE("ops.Softmax");
HWY_DASSERT(size != 0);
HWY_DASSERT(mask_pos <= size);
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
@ -685,13 +721,13 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
const V vmin = hn::Set(d, hwy::LowestValue<float>());
V vmax = vmin;
V* pmax = &vmax; // workaround for SVE: cannot capture &vector directly
hn::Foreach(d, x, mask_pos, vmin,
[pmax](const auto d, const V value)
HWY_ATTR { *pmax = hn::Max(*pmax, value); });
hn::Foreach(d, x, size, vmin, [pmax](const auto d, const V value) HWY_ATTR {
*pmax = hn::Max(*pmax, value);
});
vmax = hn::MaxOfLanes(d, vmax);
// Subtract max (avoid precision loss for large exponents) and exponentiate.
hn::Transform(d, x, mask_pos, [pmax](const auto d, const V value) HWY_ATTR {
hn::Transform(d, x, size, [pmax](const auto d, const V value) HWY_ATTR {
if constexpr (HWY_TARGET & HWY_ALL_SVE) {
// Temporary workaround for buggy SVE codegen: avoid inlined Exp().
return hn::CallExp(d, hn::Sub(value, *pmax));
@ -702,7 +738,7 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
if (temperature != 1.0f) {
const float temperature_inv = 1.0f / temperature;
hn::Transform(d, x, mask_pos,
hn::Transform(d, x, size,
[temperature_inv](const auto d, const V value) HWY_ATTR {
return hn::Mul(value, hn::Set(d, temperature_inv));
});
@ -712,16 +748,10 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
// not make a huge difference. It halves the standard deviation of the sum of
// the normalized probabilities from 1E-7 to 5E-8, but actually also changes
// the generated text after a few hundred tokens.
const float sum_exp = Sum(d, x, mask_pos);
const float sum_exp = Sum(d, x, size);
// Double-precision reciprocal does not appear to affect the results.
const float mul = 1.0f / sum_exp;
MulByConst(mul, x, size, mask_pos);
}
static HWY_INLINE HWY_MAYBE_UNUSED void Softmax(float* HWY_RESTRICT x,
const size_t size,
float temperature = 1.0f) {
Softmax(x, size, size, temperature);
MulByConst(mul, x, size);
}
// Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max /