mirror of https://github.com/google/gemma.cpp.git
1.02x batch decode speedup: BF16 KV cache
ops-inl.h: Vectorize Rope(); template MulByConst.
Remove unused MulBy, and the extra-arg overloads of MulByConst and Softmax.
Fix for DecompressAndZeroPad: ensure the second vector is filled.

PiperOrigin-RevId: 772779163
This commit is contained in:
parent
606e22155a
commit
343482c7ef
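Editor's note on the storage format: BF16 is the upper 16 bits of an IEEE-754 float32, so switching the KV cache to BF16 halves its bytes while keeping the f32 exponent range (at 8 mantissa bits of precision). A minimal standalone sketch of the conversion, with round-to-nearest-even as typically used by BF16 stores (illustrative only, NaN handling omitted; this is not the Highway code the commit uses):

#include <cstdint>
#include <cstring>

// f32 -> bf16 with round-to-nearest-even (NaN handling omitted).
static uint16_t F32ToBF16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  bits += 0x7FFF + ((bits >> 16) & 1);  // round to nearest, ties to even
  return static_cast<uint16_t>(bits >> 16);
}

// bf16 -> f32 is exact: the truncated low mantissa bits are restored as zero.
static float BF16ToF32(uint16_t b) {
  const uint32_t bits = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}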
@@ -449,6 +449,7 @@ cc_library(
     srcs = ["gemma/kv_cache.cc"],
     hdrs = ["gemma/kv_cache.h"],
     deps = [
+        ":basics",
         ":configs",
         ":gemma_args",
         ":mat",
@@ -504,6 +505,7 @@ cc_library(
         ":threading",
         ":threading_context",
         ":weights",
+        "//compression:compress",
        "//compression:types",
         "//io:blob_store",
         "//io",
@@ -517,8 +517,19 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
   }
 }
 
-// Stores two f32 vectors to f32 or bf16; avoids duplicating RMSNorm and
-// RMSNormInplace for the two output types.
+// Same as above, but without parallelization nor benchmarking.
+template <typename Packed>
+HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
+                           CompressPerThread& tls,
+                           const PackedSpan<Packed>& packed,
+                           const size_t packed_ofs) {
+  packed.BoundsCheck(packed_ofs, num);
+  using Traits = CompressTraits<hwy::RemoveConst<Packed>>;
+  const hn::ScalableTag<float> df;
+  Traits::Compress(df, raw, num, tls, packed, packed_ofs);
+}
+
+// Stores two f32 vectors to f32 or bf16.
 template <class DF, typename Packed, HWY_IF_F32_D(DF), class VF = hn::Vec<DF>>
 void Compress2(DF df, VF raw0, VF raw1, const PackedSpan<Packed>& packed,
                const size_t packed_ofs) {
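The new single-threaded Compress overload is what the attention code below uses to write RoPE'd keys and values back into the BF16 cache. A hypothetical call site, with sizes invented for illustration (the real call appears in the attention hunk further down):

HWY_ALIGN float raw[256];  // f32 working values, e.g. one k or v vector
BF16 packed[256];          // destination row in the BF16 KV cache
CompressPerThread tls;     // per-thread compression scratch state
Compress(raw, 256, tls, MakeSpan(packed, 256), /*packed_ofs=*/0);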
@@ -24,6 +24,7 @@
 #endif  // HWY_DISABLED_TARGETS
 
 #include "gemma/activations.h"
+#include "gemma/configs.h"  // kMaxQKVDim
 #include "gemma/gemma.h"
 #include "gemma/weights.h"
 #include "util/threading.h"
@@ -39,6 +40,7 @@
 #include "hwy/foreach_target.h"  // IWYU pragma: keep
 #include "hwy/highway.h"
 // After highway.h
+#include "compression/compress-inl.h"
 #include "ops/ops-inl.h"
 
 HWY_BEFORE_NAMESPACE();
@@ -50,7 +52,7 @@ namespace HWY_NAMESPACE {
 static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
                              const hwy::Divisor& div_seq_len,
                              const float* HWY_RESTRICT q,
-                             const MatPtrT<float>& k, float* HWY_RESTRICT att) {
+                             const MatPtrT<BF16>& k, float* HWY_RESTRICT att) {
   if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
     // Slightly faster: no wraparound.
     for (size_t pos = start_pos; pos <= last_pos; ++pos) {
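With BF16 keys, each dot product now widens k lanes to f32 before the multiply-accumulate. A scalar sketch of what one att[pos] entry computes (illustrative only; the real code goes through the Highway dot-product path over MatPtrT<BF16> rows):

#include <cstddef>
#include <cstdint>
#include <cstring>

static float BF16Lane(uint16_t b) {  // widen one BF16 lane to f32
  const uint32_t bits = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// att[pos] = q . k_row, with the key row stored as BF16.
static float QDotKScalar(const float* q, const uint16_t* k_row,
                         size_t qkv_dim) {
  float sum = 0.0f;
  for (size_t i = 0; i < qkv_dim; ++i) sum += q[i] * BF16Lane(k_row[i]);
  return sum;
}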
@@ -66,8 +68,7 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
   }
 }
 
-template <typename U>
-static void PositionalEncodingQK(U* qk, const size_t layer_idx,
+static void PositionalEncodingQK(float* qk, const size_t layer_idx,
                                  const LayerWeightsPtrs& layer,
                                  const AttentionActivations& activations,
                                  const size_t pos, const float mul = 1.0f) {
@@ -97,7 +98,7 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos,
                                     const size_t last_pos,
                                     const hwy::Divisor& div_seq_len,
                                     const float* HWY_RESTRICT att,
-                                    const MatPtrT<float>& v,
+                                    const MatPtrT<BF16>& v,
                                     float* HWY_RESTRICT att_out) {
   const size_t qkv_dim = v.Cols();
   hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
@@ -110,7 +111,7 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos,
   } else {
     for (size_t pos = start_pos; pos <= last_pos; ++pos) {
       const size_t pos_modulo = div_seq_len.Remainder(pos);
-      const float* HWY_RESTRICT v_ptr = v.Row(pos_modulo);
+      const BF16* HWY_RESTRICT v_ptr = v.Row(pos_modulo);
       MulByConstAndAdd(att[pos_modulo], v_ptr, att_out, v.Cols());
     }
   }
@@ -118,12 +119,14 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos,
 
 // Calculates the attention outputs for a single q, which may be updated
 // in place for RMSNorm.
-void SingleDotSoftmaxWeightedSum(
-    const size_t pos, const size_t start_pos, const size_t last_pos,
-    float* HWY_RESTRICT q, const MatPtrT<float>& k, const MatPtrT<float>& v,
-    const size_t layer_idx, const LayerWeightsPtrs& layer,
-    const AttentionActivations& activations, float* HWY_RESTRICT att,
-    float* HWY_RESTRICT att_out) {
+void SingleDotSoftmaxWeightedSum(const size_t pos, const size_t start_pos,
+                                 const size_t last_pos, float* HWY_RESTRICT q,
+                                 const MatPtrT<BF16>& k, const MatPtrT<BF16>& v,
+                                 const size_t layer_idx,
+                                 const LayerWeightsPtrs& layer,
+                                 const AttentionActivations& activations,
+                                 float* HWY_RESTRICT att,
+                                 float* HWY_RESTRICT att_out) {
   const float att_cap = activations.config.att_cap;
   const float query_scale = activations.query_scale;
   const size_t seq_len =
@@ -136,6 +139,7 @@ void SingleDotSoftmaxWeightedSum(
                    layer.layer_config.qkv_dim);
     });
   }
 
   PositionalEncodingQK(q, layer_idx, layer, activations, pos, query_scale);
 
   QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att);
@@ -220,10 +224,10 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
         // this query and head.
         const size_t kv_head_offset =
             layer_idx * cache_layer_size + head_offset;
-        MatPtrT<float> k("k_view", Extents2D(seq_len, qkv_dim));
+        MatPtrT<BF16> k("k_view", Extents2D(seq_len, qkv_dim));
         k.SetPtr(kv_cache.Row(0) + kv_head_offset,
                  kv_cache.Stride());
-        MatPtrT<float> v("v_view", Extents2D(seq_len, qkv_dim));
+        MatPtrT<BF16> v("v_view", Extents2D(seq_len, qkv_dim));
         v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim,
                  kv_cache.Stride());
@@ -263,8 +267,8 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
   // Set up MatMul row pointers for writing to KV, which consists of
   // `kv_heads` pairs of (k, v) vectors. This safely handles wraparound
   // because rows are computed modulo seq_len.
-  MatPtrT<float> kv_rows("kv", Extents2D(activations.pre_att_rms_out.Rows(),
-                                         layer.qkv_einsum_w2.Rows()));
+  MatPtrT<BF16> kv_rows("kv", Extents2D(activations.pre_att_rms_out.Rows(),
+                                        layer.qkv_einsum_w2.Rows()));
   for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
        ++interleaved_idx) {
     const size_t qi = div_qbatch.Remainder(interleaved_idx);
@@ -291,9 +295,9 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
       const size_t pos = qbatch.Pos(qi) + batch_idx;
       const size_t cache_pos = activations.div_seq_len.Remainder(pos);
       auto& kv_cache = qbatch.KV(qi).kv_cache;
-      float* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
-                               layer_idx * cache_layer_size +
-                               head * qkv_dim * 2;
+      BF16* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
+                              layer_idx * cache_layer_size +
+                              head * qkv_dim * 2;
 
       // Apply further processing to K.
       if (layer.key_norm_scale.HasPtr()) {
@@ -302,7 +306,12 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
         });
       }
 
-      PositionalEncodingQK(kv, layer_idx, layer, activations, pos);
+      HWY_ALIGN float kv_f32[kMaxQKVDim];
+      const hn::ScalableTag<float> df;
+      DecompressAndZeroPad(df, MakeSpan(kv, qkv_dim), 0, kv_f32, qkv_dim);
+      PositionalEncodingQK(kv_f32, layer_idx, layer, activations, pos);
+      CompressPerThread tls;
+      Compress(kv_f32, qkv_dim, tls, MakeSpan(kv, qkv_dim), 0);
     });
   }
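This hunk is the heart of the change: K and V now live in the cache as BF16, but RoPE still runs in f32, so each freshly written (k, v) vector is decompressed into an f32 stack buffer, rotated, and recompressed. A scalar model of the same round trip, with hypothetical helper names FromBF16/ToBF16/RopeRoundTrip (the real code uses DecompressAndZeroPad/Compress with Highway vectors, as above):

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>

static float FromBF16(uint16_t b) {
  const uint32_t u = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}
static uint16_t ToBF16(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  u += 0x7FFF + ((u >> 16) & 1);  // round to nearest, ties to even
  return static_cast<uint16_t>(u >> 16);
}

// Decompress BF16 -> f32, apply the half-split RoPE rotation (same math as
// the scalar loop removed from ops-inl.h below), recompress to BF16.
static void RopeRoundTrip(uint16_t* kv, size_t qkv_dim,
                          const float* inv_timescale, int pos) {
  const size_t half = qkv_dim / 2;
  for (size_t dim = 0; dim < half; ++dim) {
    const float theta = static_cast<float>(pos) * inv_timescale[dim];
    const float c = cosf(theta), s = sinf(theta);
    const float x0 = FromBF16(kv[dim]);
    const float x1 = FromBF16(kv[dim + half]);
    kv[dim] = ToBF16(x0 * c - x1 * s);
    kv[dim + half] = ToBF16(x0 * s + x1 * c);
  }
}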
@@ -26,25 +26,25 @@
 namespace gcpp {
 
 // Passed to HWY_VISIT_TARGETS; declares for one target.
 #define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE)                              \
   namespace NAMESPACE {                                                     \
   void SingleDotSoftmaxWeightedSum(                                         \
       const size_t pos, const size_t start_pos, const size_t last_pos,      \
-      float* HWY_RESTRICT q, const MatPtrT<float>& k, const MatPtrT<float>& v, \
+      float* HWY_RESTRICT q, const MatPtrT<BF16>& k, const MatPtrT<BF16>& v, \
       size_t layer_idx, const LayerWeightsPtrs& layer,                      \
       const AttentionActivations& activations, float* HWY_RESTRICT att,     \
       float* HWY_RESTRICT att_out);                                         \
                                                                             \
   void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx,     \
                              const LayerWeightsPtrs& layer,                 \
                              AttentionActivations& activations,             \
                              QBatch& qbatch, NestedPools& pools);           \
                                                                             \
   void GemmaAttention(size_t num_tokens, const size_t layer_idx,            \
                       const LayerWeightsPtrs& layer,                        \
                       AttentionActivations& activations, QBatch& qbatch,    \
                       MatMulEnv& env, int flags);                           \
   /* NOLINTNEXTLINE(google-readability-namespace-comments) */               \
   }  // namespace NAMESPACE
 
 // Function declarations for each SIMD target. Allows direct call from the
@@ -32,6 +32,7 @@
 namespace gcpp {
 
 static constexpr size_t kMaxConv1DWidth = 4;
+static constexpr size_t kMaxQKVDim = 1024;
 
 // Instruction-tuned models require extra 'turn structure' tokens in prompts.
 enum class PromptWrapping {
@@ -19,7 +19,8 @@
 #include <stddef.h>
 
 #include "gemma/configs.h"  // ModelConfig
-#include "gemma/gemma_args.h"
+#include "gemma/gemma_args.h"  // InferenceArgs
+#include "util/basics.h"  // BF16
 #include "util/mat.h"
 
 namespace gcpp {
@@ -41,7 +42,7 @@ struct KVCache {
   MatStorageT<float> conv1d_cache;
   MatStorageT<float> rglru_cache;  // [griffin_layers, model_dim]
 
-  MatStorageT<float> kv_cache;  // [seq_len, layers * kv_heads * qkv_dim * 2]
+  MatStorageT<BF16> kv_cache;  // [seq_len, layers * kv_heads * qkv_dim * 2]
 
  private:
  // For use by other ctor and Copy()
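The payoff of the MatStorageT<BF16> switch is cache size: every KV element drops from 4 to 2 bytes, which also halves the memory traffic per decoded token — presumably the source of the batch-decode speedup in the title. Back-of-envelope with a hypothetical configuration (not any specific Gemma model):

#include <cstddef>
#include <cstdint>

// Elements = seq_len * layers * kv_heads * qkv_dim * 2 (one k and one v).
constexpr size_t kSeqLen = 4096, kLayers = 32, kKVHeads = 8, kQKVDim = 256;
constexpr size_t kElems = kSeqLen * kLayers * kKVHeads * kQKVDim * 2;
static_assert(kElems * sizeof(float) == 2ull << 30, "f32 cache: 2 GiB");
static_assert(kElems * sizeof(uint16_t) == 1ull << 30, "BF16 cache: 1 GiB");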
@@ -25,7 +25,7 @@
 #include <string>
 
 #include "compression/types.h"
-#include "gemma/configs.h"  // ModelConfig
+#include "gemma/configs.h"  // ModelConfig, kMaxQKVDim
 #include "gemma/tensor_info.h"
 #include "gemma/tokenizer.h"
 #include "io/blob_store.h"
@@ -234,6 +234,11 @@ static ModelConfig ReadOrDeduceConfig(BlobReader& reader,
   HWY_ASSERT(config.model != Model::UNKNOWN);
   HWY_ASSERT(config.wrapping != PromptWrapping::kSentinel);
   HWY_ASSERT(config.weight != Type::kUnknown);
+  for (const LayerConfig& layer_config : config.layer_configs) {
+    if (static_cast<size_t>(layer_config.qkv_dim) > kMaxQKVDim) {
+      HWY_ABORT("Increase kMaxQKVDim to at least %u.", layer_config.qkv_dim);
+    }
+  }
 
   // We trust the deserialized config, but checking helps to validate the
   // deduction, which we rely on below for pre-2025 files.
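Note the connection to the attention change above: ComputeQKV now routes RoPE through a fixed HWY_ALIGN float kv_f32[kMaxQKVDim] stack buffer, so a model whose qkv_dim exceeds kMaxQKVDim would overflow it. This load-time loop turns that case into an immediate HWY_ABORT instead of memory corruption.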
ops/ops-inl.h (254 changed lines)
@@ -341,6 +341,10 @@ HWY_NOINLINE void LayerNorm(const XT* x, const WT* HWY_RESTRICT scale,
     HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)];
     HWY_ALIGN float buf_scale[2 * hn::MaxLanes(df)];
     HWY_ALIGN float buf_bias[2 * hn::MaxLanes(df)];
+    // Ensure the second vectors are zeroed even if remaining <= NF.
+    hn::Store(hn::Zero(df), df, buf_x + NF);
+    hn::Store(hn::Zero(df), df, buf_scale + NF);
+    hn::Store(hn::Zero(df), df, buf_bias + NF);
     HWY_ALIGN OT buf_out[2 * hn::MaxLanes(df)];
     DecompressAndZeroPad(df, packed_x, i, buf_x, remaining);
     DecompressAndZeroPad(df, packed_scale, i, buf_scale, remaining);
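Why these new stores are needed (this is the "Fix for DecompressAndZeroPad" from the commit message): DecompressAndZeroPad zero-pads only up to the next vector boundary, so when remaining <= NF the second vector's lanes keep stale stack contents, yet the code below unconditionally loads buf + NF. Restating the pattern in isolation, using the same calls as the diff:

HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)];
// If remaining <= NF, DecompressAndZeroPad fills (and pads) only the first
// vector, so explicitly zero the second vector before loading it.
hn::Store(hn::Zero(df), df, buf_x + NF);
DecompressAndZeroPad(df, packed_x, i, buf_x, remaining);
const VF x0 = hn::Load(df, buf_x);       // valid data, zero-padded
const VF x1 = hn::Load(df, buf_x + NF);  // all zero when remaining <= NF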
@@ -399,94 +403,122 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
   of this rotation matrix which is simply the same matrix with -pos parameter)
 */
 
-// `inv_timescale[dim_qkv / 2]` is precomputed in Activations::Allocate.
-// This overload is called if kUseHalfRope.
+// `inv_timescale[dim_qkv / 2]` is precomputed in AttentionActivations.
+// This overload is called if `post_qk == PostQKType::HalfRope`.
 static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
     float* HWY_RESTRICT x, size_t dim_qkv,
-    const float* HWY_RESTRICT inv_timescale, int pos) {
+    const float* HWY_RESTRICT inv_timescale, const int pos) {
   PROFILER_ZONE("ops.Rope");
   HWY_DASSERT(dim_qkv % 2 == 0);
   const size_t half_dim_qkv = dim_qkv / 2;
-  for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
-    const float theta = StaticCast<float>(pos) * inv_timescale[dim];
-    const float cos_val = cosf(theta);
-    const float sin_val = sinf(theta);
-    const float x0 = x[dim];
-    const float x1 = x[dim + half_dim_qkv];
-    x[dim] = x0 * cos_val - x1 * sin_val;
-    x[dim + half_dim_qkv] = x0 * sin_val + x1 * cos_val;
-  }
-}
-
-// `inv_timescale[dim_qkv / 2]` is precomputed in Activations::Allocate.
-static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
-    const float mul, float* HWY_RESTRICT x, size_t dim_qkv,
-    const float* HWY_RESTRICT inv_timescale, int pos) {
-  PROFILER_ZONE("ops.RopeAndMulBy");
-  HWY_DASSERT(dim_qkv % 2 == 0);
-  const size_t half_dim_qkv = dim_qkv / 2;
-
-  using D = hn::ScalableTag<float>;
-  using V = hn::Vec<D>;
-  const D d;
+  const hn::ScalableTag<float> df;
+  const size_t NF = hn::Lanes(df);
+  using VF = hn::Vec<decltype(df)>;
+  const VF vpos = hn::Set(df, static_cast<float>(pos));
 
   // Vectorize computation for half_dim_qkv - (half_dim_qkv % Lanes)
-  const size_t vectorizable_dims = hwy::RoundDownTo(half_dim_qkv, hn::Lanes(d));
+  const size_t vectorizable_dims = hwy::RoundDownTo(half_dim_qkv, NF);
   size_t dim = 0;
-  for (; dim < vectorizable_dims; dim += hn::Lanes(d)) {
-    // Compute thetas
-    V pos_vec = hn::Set(d, pos);
-    V inv_time_scale_vec = hn::LoadU(d, inv_timescale + dim);
-    V theta_vec = hn::Mul(pos_vec, inv_time_scale_vec);
+  for (; dim < vectorizable_dims; dim += NF) {
+    const VF vinv_time_scale = hn::LoadU(df, inv_timescale + dim);
+    const VF vtheta = hn::Mul(vpos, vinv_time_scale);
 
     // Compute rotations.
-    V cos_theta_vec;
-    V sin_theta_vec;
-    hn::SinCos(d, theta_vec, sin_theta_vec, cos_theta_vec);
+    VF vcos_theta;
+    VF vsin_theta;
+    hn::SinCos(df, vtheta, vsin_theta, vcos_theta);
 
-    // Scale input with rotations and multiply with constant.
-    V mul_vec = hn::Set(d, mul);
-    V x0_vec = hn::Mul(mul_vec, hn::LoadU(d, x + dim));
-    V x1_vec = hn::Mul(mul_vec, hn::LoadU(d, x + dim + half_dim_qkv));
-
-    V xout_0_vec = hn::MulSub(x0_vec, cos_theta_vec,
-                              hn::Mul(x1_vec, sin_theta_vec));
-    V xout_1_vec = hn::MulAdd(x0_vec, sin_theta_vec,
-                              hn::Mul(x1_vec, cos_theta_vec));
-
-    // Store
-    hn::StoreU(xout_0_vec, d, x + dim);
-    hn::StoreU(xout_1_vec, d, x + dim + half_dim_qkv);
+    // Scale input with rotations. Note: the rotated outputs are written to
+    // separate vectors so that vout1 sees the *original* vx0.
+    const VF vx0 = hn::LoadU(df, x + dim);
+    const VF vx1 = hn::LoadU(df, x + dim + half_dim_qkv);
+    const VF vout0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
+    const VF vout1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
+
+    hn::StoreU(vout0, df, x + dim);
+    hn::StoreU(vout1, df, x + dim + half_dim_qkv);
   }
 
   // Vectorize computation for remaining dims - same as above, but with LoadN.
   const size_t remaining_dims = half_dim_qkv - dim;
-  HWY_DASSERT(remaining_dims < hn::Lanes(d));  // at most one iteration
+  HWY_DASSERT(remaining_dims < NF);  // at most one iteration
   if (remaining_dims != 0) {
-    // Compute thetas
-    V pos_vec = hn::Set(d, pos);
-    V inv_time_scale_vec = hn::LoadN(d, inv_timescale + dim, remaining_dims);
-    V theta_vec = hn::Mul(pos_vec, inv_time_scale_vec);
+    VF vinv_time_scale = hn::LoadN(df, inv_timescale + dim, remaining_dims);
+    VF vtheta = hn::Mul(vpos, vinv_time_scale);
 
     // Compute rotations.
-    V cos_theta_vec;
-    V sin_theta_vec;
-    hn::SinCos(d, theta_vec, sin_theta_vec, cos_theta_vec);
+    VF vcos_theta;
+    VF vsin_theta;
+    hn::SinCos(df, vtheta, vsin_theta, vcos_theta);
 
-    // Scale input with rotations and multiply with constant.
-    V mul_vec = hn::Set(d, mul);
-    V x0_vec = hn::Mul(mul_vec, hn::LoadN(d, x + dim, remaining_dims));
-    V x1_vec =
-        hn::Mul(mul_vec, hn::LoadN(d, x + dim + half_dim_qkv, remaining_dims));
-
-    V xout_0_vec =
-        hn::MulSub(x0_vec, cos_theta_vec, hn::Mul(x1_vec, sin_theta_vec));
-    V xout_1_vec =
-        hn::MulAdd(x0_vec, sin_theta_vec, hn::Mul(x1_vec, cos_theta_vec));
-
-    // Store
-    hn::StoreN(xout_0_vec, d, x + dim, remaining_dims);
-    hn::StoreN(xout_1_vec, d, x + dim + half_dim_qkv, remaining_dims);
+    // Scale input with rotations.
+    const VF vx0 = hn::LoadN(df, x + dim, remaining_dims);
+    const VF vx1 = hn::LoadN(df, x + dim + half_dim_qkv, remaining_dims);
+    const VF vout0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
+    const VF vout1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
+
+    hn::StoreN(vout0, df, x + dim, remaining_dims);
+    hn::StoreN(vout1, df, x + dim + half_dim_qkv, remaining_dims);
   }
 }
 
+// `inv_timescale[dim_qkv / 2]` is precomputed in AttentionActivations.
+static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
+    const float mul, float* HWY_RESTRICT x, size_t dim_qkv,
+    const float* HWY_RESTRICT inv_timescale, const int pos) {
+  PROFILER_ZONE("ops.RopeAndMulBy");
+  HWY_DASSERT(dim_qkv % 2 == 0);
+  const size_t half_dim_qkv = dim_qkv / 2;
+
+  const hn::ScalableTag<float> df;
+  const size_t NF = hn::Lanes(df);
+  using VF = hn::Vec<decltype(df)>;
+  const VF vmul = hn::Set(df, mul);
+  const VF vpos = hn::Set(df, static_cast<float>(pos));
+
+  // Vectorize computation for half_dim_qkv - (half_dim_qkv % Lanes)
+  const size_t vectorizable_dims = hwy::RoundDownTo(half_dim_qkv, NF);
+  size_t dim = 0;
+  for (; dim < vectorizable_dims; dim += NF) {
+    const VF vinv_time_scale = hn::LoadU(df, inv_timescale + dim);
+    const VF vtheta = hn::Mul(vpos, vinv_time_scale);
+
+    // Compute rotations.
+    VF vcos_theta;
+    VF vsin_theta;
+    hn::SinCos(df, vtheta, vsin_theta, vcos_theta);
+
+    // Scale input with rotations and multiply with constant.
+    const VF vx0 = hn::Mul(vmul, hn::LoadU(df, x + dim));
+    const VF vx1 = hn::Mul(vmul, hn::LoadU(df, x + dim + half_dim_qkv));
+    const VF vout0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
+    const VF vout1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
+
+    hn::StoreU(vout0, df, x + dim);
+    hn::StoreU(vout1, df, x + dim + half_dim_qkv);
+  }
+
+  // Vectorize computation for remaining dims - same as above, but with LoadN.
+  const size_t remaining_dims = half_dim_qkv - dim;
+  HWY_DASSERT(remaining_dims < NF);  // at most one iteration
+  if (remaining_dims != 0) {
+    VF vinv_time_scale = hn::LoadN(df, inv_timescale + dim, remaining_dims);
+    VF vtheta = hn::Mul(vpos, vinv_time_scale);
+
+    // Compute rotations.
+    VF vcos_theta;
+    VF vsin_theta;
+    hn::SinCos(df, vtheta, vsin_theta, vcos_theta);
+
+    // Scale input with rotations and multiply with constant.
+    const VF vx0 = hn::Mul(vmul, hn::LoadN(df, x + dim, remaining_dims));
+    const VF vx1 =
+        hn::Mul(vmul, hn::LoadN(df, x + dim + half_dim_qkv, remaining_dims));
+    const VF vout0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
+    const VF vout1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
+
+    hn::StoreN(vout0, df, x + dim, remaining_dims);
+    hn::StoreN(vout1, df, x + dim + half_dim_qkv, remaining_dims);
+  }
+}
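Editor's note: the page rendered the rotated outputs as in-place reassignments of vx0/vx1, which would feed the already-rotated vx0 into the vx1 computation; the reconstruction above restores distinct output vectors (vout0/vout1), matching the removed code's xout_0_vec/xout_1_vec. Since the scalar Rope loop is deleted in the same hunk that vectorizes it, a regression check is easy to reconstruct. A hypothetical test sketch (scalar reference copied from the removed code; data and tolerance invented):

#include <cmath>
#include <cstddef>
#include <vector>

// Scalar reference: the loop removed from ops-inl.h.
void ScalarRope(float* x, size_t dim_qkv, const float* inv_timescale, int pos) {
  const size_t half = dim_qkv / 2;
  for (size_t dim = 0; dim < half; ++dim) {
    const float theta = static_cast<float>(pos) * inv_timescale[dim];
    const float c = cosf(theta), s = sinf(theta);
    const float x0 = x[dim], x1 = x[dim + half];
    x[dim] = x0 * c - x1 * s;
    x[dim + half] = x0 * s + x1 * c;
  }
}

void CheckRope(size_t dim_qkv, const float* inv_timescale, int pos) {
  std::vector<float> a(dim_qkv);
  for (size_t i = 0; i < dim_qkv; ++i) a[i] = 0.01f * static_cast<float>(i);
  std::vector<float> b = a;
  ScalarRope(a.data(), dim_qkv, inv_timescale, pos);
  Rope(b.data(), dim_qkv, inv_timescale, pos);  // vectorized version above
  for (size_t i = 0; i < dim_qkv; ++i) {
    // Allow for SinCos vs. cosf/sinf rounding differences.
    HWY_ASSERT(std::abs(a[i] - b[i]) < 1e-5f);
  }
}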
@@ -521,6 +553,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(const XT* HWY_RESTRICT x,
   HWY_DASSERT(remaining1 < NF);
   if (HWY_UNLIKELY(remaining != 0)) {
     HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)];
+    // Ensure the second vector is zeroed even if remaining <= NF.
+    hn::Store(hn::Zero(df), df, buf_x + NF);
     DecompressAndZeroPad(df, packed_x, i, buf_x, remaining);
     const VF x0 = hn::Load(df, buf_x);
     const VF x1 = hn::Load(df, buf_x + NF);
@@ -586,42 +620,43 @@ static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x,
   }
 }
 
-static HWY_NOINLINE void MulBy(const float* HWY_RESTRICT other,
-                               float* HWY_RESTRICT x, const size_t size,
-                               const size_t max_pos) {
-  PROFILER_ZONE("ops.MulBy");
-  HWY_DASSERT(max_pos <= size);
-  namespace hn = hwy::HWY_NAMESPACE;
-  using D = hn::ScalableTag<float>;
-  using V = hn::Vec<D>;
-
-  hn::Transform1(D(), x, max_pos, other,
-                 [](const auto d, const V x, const V other)
-                     HWY_ATTR { return hn::Mul(x, other); });
-}
-
-static HWY_INLINE HWY_MAYBE_UNUSED void MulBy(const float* HWY_RESTRICT other,
-                                              float* HWY_RESTRICT x,
-                                              const size_t size) {
-  return MulBy(other, x, size, size);
-}
-
-static HWY_NOINLINE void MulByConst(const float c, float* HWY_RESTRICT x,
-                                    const size_t size, const size_t max_pos) {
+template <typename XT>
+HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(float c, XT* HWY_RESTRICT x,
+                                              size_t size) {
   PROFILER_ZONE("ops.MulByConst");
-  HWY_DASSERT(max_pos <= size);
-  namespace hn = hwy::HWY_NAMESPACE;
-  using D = hn::ScalableTag<float>;
-  using V = hn::Vec<D>;
-  hn::Transform(D(), x, max_pos, [c](const auto d, const V x) HWY_ATTR {
-    return hn::Mul(x, hn::Set(d, c));
-  });
-}
+  const hn::ScalableTag<float> df;
+  const size_t NF = hn::Lanes(df);
+  using VF = hn::Vec<decltype(df)>;
 
-static HWY_INLINE HWY_MAYBE_UNUSED void MulByConst(const float c,
-                                                   float* HWY_RESTRICT x,
-                                                   const size_t size) {
-  MulByConst(c, x, size, size);
+  const VF v_c = hn::Set(df, c);
+  const auto packed_x = MakeSpan(x, size);
+
+  size_t i = 0;
+  if (size >= 2 * NF) {
+    for (; i <= size - 2 * NF; i += 2 * NF) {
+      VF x0, x1;
+      Decompress2(df, packed_x, i, x0, x1);
+      x0 = hn::Mul(x0, v_c);
+      x1 = hn::Mul(x1, v_c);
+      Compress2(df, x0, x1, packed_x, i);
+    }
+  }
+
+  const size_t remaining = size - i;
+  HWY_DASSERT(remaining < 2 * NF);
+  if (HWY_UNLIKELY(remaining != 0)) {
+    HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)];
+    // Ensure the second vector is zeroed even if remaining <= NF.
+    hn::Store(hn::Zero(df), df, buf_x + NF);
+    DecompressAndZeroPad(df, packed_x, i, buf_x, remaining);
+    VF x0 = hn::Load(df, buf_x);
+    VF x1 = hn::Load(df, buf_x + NF);
+    x0 = hn::Mul(x0, v_c);
+    x1 = hn::Mul(x1, v_c);
+    // Compress into an XT-typed buffer so the byte copy below is valid for
+    // both f32 and BF16 (the page rendered this as the float buf_x, which
+    // would only be correct for XT = float).
+    HWY_ALIGN XT buf_out[2 * hn::MaxLanes(df)];
+    Compress2(df, x0, x1, MakeSpan(buf_out, 2 * NF), 0);
+    hwy::CopyBytes(buf_out, x + i, remaining * sizeof(XT));
+  }
 }
 
 template <typename XT, typename OT>
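With the template parameter XT, one decompress-multiply-recompress loop serves both element types; the removed float-only Transform version could not touch BF16 arrays. A usage sketch (array contents arbitrary, assuming the buf_out fix noted above):

HWY_ALIGN float f32_row[512];
MulByConst(0.5f, f32_row, 512);  // float path: Decompress2/Compress2 are plain loads/stores

HWY_ALIGN BF16 bf16_row[512];
MulByConst(0.5f, bf16_row, 512);  // widens to f32, scales, rounds back to BF16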
@@ -656,6 +691,9 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(float c,
   if (HWY_UNLIKELY(remaining != 0)) {
     HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)];
     HWY_ALIGN float buf_out[2 * hn::MaxLanes(df)];
+    // Ensure the second vectors are zeroed even if remaining <= NF.
+    hn::Store(hn::Zero(df), df, buf_x + NF);
+    hn::Store(hn::Zero(df), df, buf_out + NF);
     DecompressAndZeroPad(df, packed_x, i, buf_x, remaining);
     DecompressAndZeroPad(df, packed_out, i, buf_out, remaining);
     const VF x0 = hn::Load(df, buf_x);
@@ -671,11 +709,9 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(float c,
 
 // See below for a specialized version for top-1 sampling.
 static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
-                                 const size_t mask_pos,
                                  float temperature = 1.0f) {
   PROFILER_ZONE("ops.Softmax");
   HWY_DASSERT(size != 0);
-  HWY_DASSERT(mask_pos <= size);
 
   namespace hn = hwy::HWY_NAMESPACE;
   using D = hn::ScalableTag<float>;
@@ -685,13 +721,13 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
   const V vmin = hn::Set(d, hwy::LowestValue<float>());
   V vmax = vmin;
   V* pmax = &vmax;  // workaround for SVE: cannot capture &vector directly
-  hn::Foreach(d, x, mask_pos, vmin,
-              [pmax](const auto d, const V value)
-                  HWY_ATTR { *pmax = hn::Max(*pmax, value); });
+  hn::Foreach(d, x, size, vmin, [pmax](const auto d, const V value) HWY_ATTR {
+    *pmax = hn::Max(*pmax, value);
+  });
   vmax = hn::MaxOfLanes(d, vmax);
 
   // Subtract max (avoid precision loss for large exponents) and exponentiate.
-  hn::Transform(d, x, mask_pos, [pmax](const auto d, const V value) HWY_ATTR {
+  hn::Transform(d, x, size, [pmax](const auto d, const V value) HWY_ATTR {
     if constexpr (HWY_TARGET & HWY_ALL_SVE) {
       // Temporary workaround for buggy SVE codegen: avoid inlined Exp().
       return hn::CallExp(d, hn::Sub(value, *pmax));
@@ -702,7 +738,7 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
 
   if (temperature != 1.0f) {
     const float temperature_inv = 1.0f / temperature;
-    hn::Transform(d, x, mask_pos,
+    hn::Transform(d, x, size,
                   [temperature_inv](const auto d, const V value) HWY_ATTR {
                     return hn::Mul(value, hn::Set(d, temperature_inv));
                   });
@@ -712,16 +748,10 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
   // not make a huge difference. It halves the standard deviation of the sum of
   // the normalized probabilities from 1E-7 to 5E-8, but actually also changes
   // the generated text after a few hundred tokens.
-  const float sum_exp = Sum(d, x, mask_pos);
+  const float sum_exp = Sum(d, x, size);
   // Double-precision reciprocal does not appear to affect the results.
   const float mul = 1.0f / sum_exp;
-  MulByConst(mul, x, size, mask_pos);
-}
-
-static HWY_INLINE HWY_MAYBE_UNUSED void Softmax(float* HWY_RESTRICT x,
-                                                const size_t size,
-                                                float temperature = 1.0f) {
-  Softmax(x, size, size, temperature);
+  MulByConst(mul, x, size);
 }
 
 // Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max /
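After dropping mask_pos, Softmax always normalizes the full size range, so the forwarding overload became redundant and is removed. A scalar model of the surviving function's pass order, for reference (illustrative only; it mirrors the code above):

#include <algorithm>
#include <cmath>
#include <cstddef>

void SoftmaxScalar(float* x, size_t size, float temperature = 1.0f) {
  float max = x[0];
  for (size_t i = 1; i < size; ++i) max = std::max(max, x[i]);
  // Subtract max before exponentiating for numerical stability.
  for (size_t i = 0; i < size; ++i) x[i] = std::exp(x[i] - max);
  if (temperature != 1.0f) {
    // (As in the vector code; note a uniform post-exp scale cancels in the
    // normalization below.)
    const float inv_t = 1.0f / temperature;
    for (size_t i = 0; i < size; ++i) x[i] *= inv_t;
  }
  float sum = 0.0f;
  for (size_t i = 0; i < size; ++i) sum += x[i];
  const float mul = 1.0f / sum;
  for (size_t i = 0; i < size; ++i) x[i] *= mul;
}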