mirror of https://github.com/google/gemma.cpp.git
1.02x batch decode speedup: BF16 KV cache
ops-inl.h: Vectorize Rope(), template MulByConst.
Remove unused MulBy, and extra-arg overloads of MulByConst and Softmax.
Fix for DecompressAndZeroPad: ensure second vector filled.
PiperOrigin-RevId: 772779163
commit 343482c7ef
parent 606e22155a
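Background for the change: storing K and V as BF16 halves the bytes the decode loop streams from the KV cache while keeping 8 bits of exponent and ~8 bits of mantissa. A minimal, self-contained sketch of the f32/bf16 round-trip; the helper names are illustrative, the repo's real path goes through CompressTraits and DecompressAndZeroPad.

#include <cstdint>
#include <cstdio>
#include <cstring>

using BF16 = uint16_t;  // the top 16 bits of an IEEE-754 binary32

static BF16 F32ToBF16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  bits += 0x7FFF + ((bits >> 16) & 1);  // round to nearest even
  return static_cast<BF16>(bits >> 16);
}

static float BF16ToF32(BF16 b) {
  const uint32_t bits = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  // A KV entry occupies 2 bytes instead of 4; about 3 significant
  // decimal digits survive the round-trip.
  const float k = 0.123456f;
  std::printf("f32 %.6f -> bf16 round-trip %.6f\n", k, BF16ToF32(F32ToBF16(k)));
  return 0;
}

Rounding to nearest even keeps the expected error at half a bf16 ULP; plain truncation would bias stored keys and values toward zero.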
@@ -449,6 +449,7 @@ cc_library(
     srcs = ["gemma/kv_cache.cc"],
     hdrs = ["gemma/kv_cache.h"],
     deps = [
+        ":basics",
         ":configs",
         ":gemma_args",
         ":mat",
@@ -504,6 +505,7 @@ cc_library(
         ":threading",
         ":threading_context",
         ":weights",
+        "//compression:compress",
         "//compression:types",
         "//io:blob_store",
         "//io",
@@ -517,8 +517,19 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
   }
 }
 
-// Stores two f32 vectors to f32 or bf16; avoids duplicating RMSNorm and
-// RMSNormInplace for the two output types.
+// Same as above, but without parallelization nor benchmarking.
+template <typename Packed>
+HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
+                           CompressPerThread& tls,
+                           const PackedSpan<Packed>& packed,
+                           const size_t packed_ofs) {
+  packed.BoundsCheck(packed_ofs, num);
+  using Traits = CompressTraits<hwy::RemoveConst<Packed>>;
+  const hn::ScalableTag<float> df;
+  Traits::Compress(df, raw, num, tls, packed, packed_ofs);
+}
+
+// Stores two f32 vectors to f32 or bf16.
 template <class DF, typename Packed, HWY_IF_F32_D(DF), class VF = hn::Vec<DF>>
 void Compress2(DF df, VF raw0, VF raw1, const PackedSpan<Packed>& packed,
                const size_t packed_ofs) {
@@ -24,6 +24,7 @@
 #endif  // HWY_DISABLED_TARGETS
 
 #include "gemma/activations.h"
+#include "gemma/configs.h"  // kMaxQKVDim
 #include "gemma/gemma.h"
 #include "gemma/weights.h"
 #include "util/threading.h"
@@ -39,6 +40,7 @@
 #include "hwy/foreach_target.h"  // IWYU pragma: keep
 #include "hwy/highway.h"
 // After highway.h
+#include "compression/compress-inl.h"
 #include "ops/ops-inl.h"
 
 HWY_BEFORE_NAMESPACE();
@@ -50,7 +52,7 @@ namespace HWY_NAMESPACE {
 static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
                              const hwy::Divisor& div_seq_len,
                              const float* HWY_RESTRICT q,
-                             const MatPtrT<float>& k, float* HWY_RESTRICT att) {
+                             const MatPtrT<BF16>& k, float* HWY_RESTRICT att) {
   if (HWY_LIKELY(last_pos < static_cast<size_t>(div_seq_len.GetDivisor()))) {
     // Slightly faster: no wraparound.
     for (size_t pos = start_pos; pos <= last_pos; ++pos) {
@@ -66,8 +68,7 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
     }
   }
 }
 
-template <typename U>
-static void PositionalEncodingQK(U* qk, const size_t layer_idx,
+static void PositionalEncodingQK(float* qk, const size_t layer_idx,
                                  const LayerWeightsPtrs& layer,
                                  const AttentionActivations& activations,
                                  const size_t pos, const float mul = 1.0f) {
@@ -97,7 +98,7 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos,
                                     const size_t last_pos,
                                     const hwy::Divisor& div_seq_len,
                                     const float* HWY_RESTRICT att,
-                                    const MatPtrT<float>& v,
+                                    const MatPtrT<BF16>& v,
                                     float* HWY_RESTRICT att_out) {
   const size_t qkv_dim = v.Cols();
   hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
@@ -110,7 +111,7 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos,
   } else {
     for (size_t pos = start_pos; pos <= last_pos; ++pos) {
       const size_t pos_modulo = div_seq_len.Remainder(pos);
-      const float* HWY_RESTRICT v_ptr = v.Row(pos_modulo);
+      const BF16* HWY_RESTRICT v_ptr = v.Row(pos_modulo);
       MulByConstAndAdd(att[pos_modulo], v_ptr, att_out, v.Cols());
     }
   }
@@ -118,11 +119,13 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos,
 
 // Calculates the attention outputs for a single q, which may be updated
 // in place for RMSNorm.
-void SingleDotSoftmaxWeightedSum(
-    const size_t pos, const size_t start_pos, const size_t last_pos,
-    float* HWY_RESTRICT q, const MatPtrT<float>& k, const MatPtrT<float>& v,
-    const size_t layer_idx, const LayerWeightsPtrs& layer,
-    const AttentionActivations& activations, float* HWY_RESTRICT att,
+void SingleDotSoftmaxWeightedSum(const size_t pos, const size_t start_pos,
+                                 const size_t last_pos, float* HWY_RESTRICT q,
+                                 const MatPtrT<BF16>& k, const MatPtrT<BF16>& v,
+                                 const size_t layer_idx,
+                                 const LayerWeightsPtrs& layer,
+                                 const AttentionActivations& activations,
+                                 float* HWY_RESTRICT att,
                                  float* HWY_RESTRICT att_out) {
   const float att_cap = activations.config.att_cap;
   const float query_scale = activations.query_scale;
@@ -136,6 +139,7 @@ void SingleDotSoftmaxWeightedSum(
                        layer.layer_config.qkv_dim);
     });
   }
+
   PositionalEncodingQK(q, layer_idx, layer, activations, pos, query_scale);
 
   QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att);
@@ -220,10 +224,10 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
         // this query and head.
         const size_t kv_head_offset =
             layer_idx * cache_layer_size + head_offset;
-        MatPtrT<float> k("k_view", Extents2D(seq_len, qkv_dim));
+        MatPtrT<BF16> k("k_view", Extents2D(seq_len, qkv_dim));
         k.SetPtr(kv_cache.Row(0) + kv_head_offset,
                  kv_cache.Stride());
-        MatPtrT<float> v("v_view", Extents2D(seq_len, qkv_dim));
+        MatPtrT<BF16> v("v_view", Extents2D(seq_len, qkv_dim));
         v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim,
                  kv_cache.Stride());
@@ -263,7 +267,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
   // Set up MatMul row pointers for writing to KV, which consists of
   // `kv_heads` pairs of (k, v) vectors. This safely handles wraparound
   // because rows are computed modulo seq_len.
-  MatPtrT<float> kv_rows("kv", Extents2D(activations.pre_att_rms_out.Rows(),
+  MatPtrT<BF16> kv_rows("kv", Extents2D(activations.pre_att_rms_out.Rows(),
                                          layer.qkv_einsum_w2.Rows()));
   for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
        ++interleaved_idx) {
@@ -291,7 +295,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
         const size_t pos = qbatch.Pos(qi) + batch_idx;
         const size_t cache_pos = activations.div_seq_len.Remainder(pos);
         auto& kv_cache = qbatch.KV(qi).kv_cache;
-        float* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
+        BF16* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
                                  layer_idx * cache_layer_size +
                                  head * qkv_dim * 2;
 
@@ -302,7 +306,12 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
         });
       }
 
-      PositionalEncodingQK(kv, layer_idx, layer, activations, pos);
+      HWY_ALIGN float kv_f32[kMaxQKVDim];
+      const hn::ScalableTag<float> df;
+      DecompressAndZeroPad(df, MakeSpan(kv, qkv_dim), 0, kv_f32, qkv_dim);
+      PositionalEncodingQK(kv_f32, layer_idx, layer, activations, pos);
+      CompressPerThread tls;
+      Compress(kv_f32, qkv_dim, tls, MakeSpan(kv, qkv_dim), 0);
     });
   }
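This hunk is the core of the change: MatMul now writes BF16 rows into the cache, and each (token, head) vector is decompressed to f32, rotated by RoPE, and recompressed in place. A self-contained scalar model of that round-trip follows; the bf16 helpers match the sketch near the top, 1024 mirrors the new kMaxQKVDim constant, and none of these names are the repo's API.

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>

using BF16 = uint16_t;
static float BF16ToF32(BF16 b) {
  const uint32_t bits = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}
static BF16 F32ToBF16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  bits += 0x7FFF + ((bits >> 16) & 1);  // round to nearest even
  return static_cast<BF16>(bits >> 16);
}

// Decompress one KV vector to f32, rotate pairs (i, i + D/2) as
// PositionalEncodingQK does, then recompress into the BF16 cache row.
void WriteRotatedKV(BF16* kv, size_t qkv_dim,
                    const float* inv_timescale, int pos) {
  float kv_f32[1024];  // 1024 models the kMaxQKVDim bound added here
  for (size_t i = 0; i < qkv_dim; ++i) kv_f32[i] = BF16ToF32(kv[i]);
  const size_t half = qkv_dim / 2;
  for (size_t i = 0; i < half; ++i) {
    const float theta = static_cast<float>(pos) * inv_timescale[i];
    const float c = std::cos(theta), s = std::sin(theta);
    const float x0 = kv_f32[i], x1 = kv_f32[i + half];
    kv_f32[i] = x0 * c - x1 * s;
    kv_f32[i + half] = x0 * s + x1 * c;
  }
  for (size_t i = 0; i < qkv_dim; ++i) kv[i] = F32ToBF16(kv_f32[i]);
}

Rotating in f32 and compressing once per write keeps the quantization error to a single rounding step rather than accumulating it through the trigonometry.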
@@ -30,7 +30,7 @@ namespace gcpp {
   namespace NAMESPACE {                                                    \
   void SingleDotSoftmaxWeightedSum(                                        \
       const size_t pos, const size_t start_pos, const size_t last_pos,     \
-      float* HWY_RESTRICT q, const MatPtrT<float>& k, const MatPtrT<float>& v, \
+      float* HWY_RESTRICT q, const MatPtrT<BF16>& k, const MatPtrT<BF16>& v, \
       size_t layer_idx, const LayerWeightsPtrs& layer,                     \
       const AttentionActivations& activations, float* HWY_RESTRICT att,    \
       float* HWY_RESTRICT att_out);                                        \
@@ -32,6 +32,7 @@
 namespace gcpp {
 
 static constexpr size_t kMaxConv1DWidth = 4;
+static constexpr size_t kMaxQKVDim = 1024;
 
 // Instruction-tuned models require extra 'turn structure' tokens in prompts.
 enum class PromptWrapping {
@@ -19,7 +19,8 @@
 #include <stddef.h>
 
 #include "gemma/configs.h"  // ModelConfig
-#include "gemma/gemma_args.h"
+#include "gemma/gemma_args.h"  // InferenceArgs
+#include "util/basics.h"       // BF16
 #include "util/mat.h"
 
 namespace gcpp {
@@ -41,7 +42,7 @@ struct KVCache {
   MatStorageT<float> conv1d_cache;
   MatStorageT<float> rglru_cache;  // [griffin_layers, model_dim]
 
-  MatStorageT<float> kv_cache;  // [seq_len, layers * kv_heads * qkv_dim * 2]
+  MatStorageT<BF16> kv_cache;  // [seq_len, layers * kv_heads * qkv_dim * 2]
 
  private:
  // For use by other ctor and Copy()
@@ -25,7 +25,7 @@
 #include <string>
 
 #include "compression/types.h"
-#include "gemma/configs.h"  // ModelConfig
+#include "gemma/configs.h"  // ModelConfig, kMaxQKVDim
 #include "gemma/tensor_info.h"
 #include "gemma/tokenizer.h"
 #include "io/blob_store.h"
@@ -234,6 +234,11 @@ static ModelConfig ReadOrDeduceConfig(BlobReader& reader,
   HWY_ASSERT(config.model != Model::UNKNOWN);
   HWY_ASSERT(config.wrapping != PromptWrapping::kSentinel);
   HWY_ASSERT(config.weight != Type::kUnknown);
+  for (const LayerConfig& layer_config : config.layer_configs) {
+    if (static_cast<size_t>(layer_config.qkv_dim) > kMaxQKVDim) {
+      HWY_ABORT("Increase kMaxQKVDim to at least %u.", layer_config.qkv_dim);
+    }
+  }
 
   // We trust the deserialized config, but checking helps to validate the
   // deduction, which we rely on below for pre-2025 files.
ops/ops-inl.h (252 changed lines)
@@ -341,6 +341,10 @@ HWY_NOINLINE void LayerNorm(const XT* x, const WT* HWY_RESTRICT scale,
     HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)];
     HWY_ALIGN float buf_scale[2 * hn::MaxLanes(df)];
     HWY_ALIGN float buf_bias[2 * hn::MaxLanes(df)];
+    // Ensure the second vectors are zeroed even if remaining <= NF.
+    hn::Store(hn::Zero(df), df, buf_x + NF);
+    hn::Store(hn::Zero(df), df, buf_scale + NF);
+    hn::Store(hn::Zero(df), df, buf_bias + NF);
     HWY_ALIGN OT buf_out[2 * hn::MaxLanes(df)];
     DecompressAndZeroPad(df, packed_x, i, buf_x, remaining);
     DecompressAndZeroPad(df, packed_scale, i, buf_scale, remaining);
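The same zeroing fix recurs below in AddFrom and MulByConstAndAdd: DecompressAndZeroPad pads only up to the next whole vector, so when remaining <= NF the buffer's second vector was previously read uninitialized. A scalar model of the fixed tail handling; NF stands in for the vector lane count and the function is illustrative, not the repo's API.

#include <algorithm>
#include <cstddef>

// Decompress `remaining` (< 2 * NF) values into a buffer that is always
// read back as two full vectors. Pre-zeroing lanes [NF, 2 * NF) makes them
// defined even when remaining <= NF and the pad stops at the first vector.
void DecompressTail(const float* src, size_t remaining, float* buf,
                    size_t NF) {
  std::fill(buf + NF, buf + 2 * NF, 0.0f);  // the added hn::Store(Zero(df))
  std::copy(src, src + remaining, buf);     // "decompress"...
  const size_t padded = std::min((remaining + NF - 1) / NF * NF, 2 * NF);
  std::fill(buf + remaining, buf + padded, 0.0f);  // ...and zero-pad one vector
}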
@@ -399,94 +403,122 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
    of this rotation matrix which is simply the same matrix with -pos parameter)
 */
 
-// `inv_timescale[dim_qkv / 2]` is precomputed in Activations::Allocate.
-// This overload is called if kUseHalfRope.
+// `inv_timescale[dim_qkv / 2]` is precomputed in AttentionActivations.
+// This overload is called if `post_qk == PostQKType::HalfRope`.
 static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
     float* HWY_RESTRICT x, size_t dim_qkv,
-    const float* HWY_RESTRICT inv_timescale, int pos) {
+    const float* HWY_RESTRICT inv_timescale, const int pos) {
   PROFILER_ZONE("ops.Rope");
   HWY_DASSERT(dim_qkv % 2 == 0);
   const size_t half_dim_qkv = dim_qkv / 2;
-  for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
-    const float theta = StaticCast<float>(pos) * inv_timescale[dim];
-    const float cos_val = cosf(theta);
-    const float sin_val = sinf(theta);
-    const float x0 = x[dim];
-    const float x1 = x[dim + half_dim_qkv];
-    x[dim] = x0 * cos_val - x1 * sin_val;
-    x[dim + half_dim_qkv] = x0 * sin_val + x1 * cos_val;
-  }
-}
 
-// `inv_timescale[dim_qkv / 2]` is precomputed in Activations::Allocate.
-static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
-    const float mul, float* HWY_RESTRICT x, size_t dim_qkv,
-    const float* HWY_RESTRICT inv_timescale, int pos) {
-  PROFILER_ZONE("ops.RopeAndMulBy");
-  HWY_DASSERT(dim_qkv % 2 == 0);
-  const size_t half_dim_qkv = dim_qkv / 2;
-
-  using D = hn::ScalableTag<float>;
-  using V = hn::Vec<D>;
-  const D d;
+  const hn::ScalableTag<float> df;
+  const size_t NF = hn::Lanes(df);
+  using VF = hn::Vec<decltype(df)>;
+  const VF vpos = hn::Set(df, static_cast<float>(pos));
 
   // Vectorize computation for half_dim_qkv - (half_dim_qkv % Lanes)
-  const size_t vectorizable_dims = hwy::RoundDownTo(half_dim_qkv, hn::Lanes(d));
+  const size_t vectorizable_dims = hwy::RoundDownTo(half_dim_qkv, NF);
   size_t dim = 0;
-  for (; dim < vectorizable_dims; dim += hn::Lanes(d)) {
-    // Compute thetas
-    V pos_vec = hn::Set(d, pos);
-    V inv_time_scale_vec = hn::LoadU(d, inv_timescale + dim);
-    V theta_vec = hn::Mul(pos_vec, inv_time_scale_vec);
+  for (; dim < vectorizable_dims; dim += NF) {
+    const VF vinv_time_scale = hn::LoadU(df, inv_timescale + dim);
+    const VF vtheta = hn::Mul(vpos, vinv_time_scale);
 
     // Compute rotations.
-    V cos_theta_vec;
-    V sin_theta_vec;
-    hn::SinCos(d, theta_vec, sin_theta_vec, cos_theta_vec);
+    VF vcos_theta;
+    VF vsin_theta;
+    hn::SinCos(df, vtheta, vsin_theta, vcos_theta);
 
-    // Scale input with rotations and multiply with constant.
-    V mul_vec = hn::Set(d, mul);
-    V x0_vec = hn::Mul(mul_vec, hn::LoadU(d, x + dim));
-    V x1_vec = hn::Mul(mul_vec, hn::LoadU(d, x + dim + half_dim_qkv));
-
-    V xout_0_vec = hn::MulSub(x0_vec, cos_theta_vec,
-                              hn::Mul(x1_vec, sin_theta_vec));
-    V xout_1_vec = hn::MulAdd(x0_vec, sin_theta_vec,
-                              hn::Mul(x1_vec, cos_theta_vec));
-
-    // Store
-    hn::StoreU(xout_0_vec, d, x + dim);
-    hn::StoreU(xout_1_vec, d, x + dim + half_dim_qkv);
+    // Scale input with rotations.
+    const VF vx0 = hn::LoadU(df, x + dim);
+    const VF vx1 = hn::LoadU(df, x + dim + half_dim_qkv);
+    const VF vout0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
+    const VF vout1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
+
+    hn::StoreU(vout0, df, x + dim);
+    hn::StoreU(vout1, df, x + dim + half_dim_qkv);
   }
 
   // Vectorize computation for remaining dims - same as above, but with LoadN.
   const size_t remaining_dims = half_dim_qkv - dim;
-  HWY_DASSERT(remaining_dims < hn::Lanes(d));  // at most one iteration
+  HWY_DASSERT(remaining_dims < NF);  // at most one iteration
   if (remaining_dims != 0) {
-    // Compute thetas
-    V pos_vec = hn::Set(d, pos);
-    V inv_time_scale_vec = hn::LoadN(d, inv_timescale + dim, remaining_dims);
-    V theta_vec = hn::Mul(pos_vec, inv_time_scale_vec);
+    VF vinv_time_scale = hn::LoadN(df, inv_timescale + dim, remaining_dims);
+    VF vtheta = hn::Mul(vpos, vinv_time_scale);
 
     // Compute rotations.
-    V cos_theta_vec;
-    V sin_theta_vec;
-    hn::SinCos(d, theta_vec, sin_theta_vec, cos_theta_vec);
+    VF vcos_theta;
+    VF vsin_theta;
+    hn::SinCos(df, vtheta, vsin_theta, vcos_theta);
 
-    // Scale input with rotations and multiply with constant.
-    V mul_vec = hn::Set(d, mul);
-    V x0_vec = hn::Mul(mul_vec, hn::LoadN(d, x + dim, remaining_dims));
-    V x1_vec =
-        hn::Mul(mul_vec, hn::LoadN(d, x + dim + half_dim_qkv, remaining_dims));
-
-    V xout_0_vec =
-        hn::MulSub(x0_vec, cos_theta_vec, hn::Mul(x1_vec, sin_theta_vec));
-    V xout_1_vec =
-        hn::MulAdd(x0_vec, sin_theta_vec, hn::Mul(x1_vec, cos_theta_vec));
-
-    // Store
-    hn::StoreN(xout_0_vec, d, x + dim, remaining_dims);
-    hn::StoreN(xout_1_vec, d, x + dim + half_dim_qkv, remaining_dims);
+    // Scale input with rotations.
+    const VF vx0 = hn::LoadN(df, x + dim, remaining_dims);
+    const VF vx1 = hn::LoadN(df, x + dim + half_dim_qkv, remaining_dims);
+    const VF vout0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
+    const VF vout1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
+
+    hn::StoreN(vout0, df, x + dim, remaining_dims);
+    hn::StoreN(vout1, df, x + dim + half_dim_qkv, remaining_dims);
+  }
+}
+
+// `inv_timescale[dim_qkv / 2]` is precomputed in AttentionActivations.
+static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
+    const float mul, float* HWY_RESTRICT x, size_t dim_qkv,
+    const float* HWY_RESTRICT inv_timescale, const int pos) {
+  PROFILER_ZONE("ops.RopeAndMulBy");
+  HWY_DASSERT(dim_qkv % 2 == 0);
+  const size_t half_dim_qkv = dim_qkv / 2;
+
+  const hn::ScalableTag<float> df;
+  const size_t NF = hn::Lanes(df);
+  using VF = hn::Vec<decltype(df)>;
+  const VF vmul = hn::Set(df, mul);
+  const VF vpos = hn::Set(df, static_cast<float>(pos));
+
+  // Vectorize computation for half_dim_qkv - (half_dim_qkv % Lanes)
+  const size_t vectorizable_dims = hwy::RoundDownTo(half_dim_qkv, NF);
+  size_t dim = 0;
+  for (; dim < vectorizable_dims; dim += NF) {
+    const VF vinv_time_scale = hn::LoadU(df, inv_timescale + dim);
+    const VF vtheta = hn::Mul(vpos, vinv_time_scale);
+
+    // Compute rotations.
+    VF vcos_theta;
+    VF vsin_theta;
+    hn::SinCos(df, vtheta, vsin_theta, vcos_theta);
+
+    // Scale input with rotations and multiply with constant.
+    const VF vx0 = hn::Mul(vmul, hn::LoadU(df, x + dim));
+    const VF vx1 = hn::Mul(vmul, hn::LoadU(df, x + dim + half_dim_qkv));
+    const VF vout0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
+    const VF vout1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
+
+    hn::StoreU(vout0, df, x + dim);
+    hn::StoreU(vout1, df, x + dim + half_dim_qkv);
+  }
+
+  // Vectorize computation for remaining dims - same as above, but with LoadN.
+  const size_t remaining_dims = half_dim_qkv - dim;
+  HWY_DASSERT(remaining_dims < NF);  // at most one iteration
+  if (remaining_dims != 0) {
+    VF vinv_time_scale = hn::LoadN(df, inv_timescale + dim, remaining_dims);
+    VF vtheta = hn::Mul(vpos, vinv_time_scale);
+
+    // Compute rotations.
+    VF vcos_theta;
+    VF vsin_theta;
+    hn::SinCos(df, vtheta, vsin_theta, vcos_theta);
+
+    // Scale input with rotations and multiply with constant.
+    const VF vx0 = hn::Mul(vmul, hn::LoadN(df, x + dim, remaining_dims));
+    const VF vx1 =
+        hn::Mul(vmul, hn::LoadN(df, x + dim + half_dim_qkv, remaining_dims));
+    const VF vout0 = hn::MulSub(vx0, vcos_theta, hn::Mul(vx1, vsin_theta));
+    const VF vout1 = hn::MulAdd(vx0, vsin_theta, hn::Mul(vx1, vcos_theta));
+
+    hn::StoreN(vout0, df, x + dim, remaining_dims);
+    hn::StoreN(vout1, df, x + dim + half_dim_qkv, remaining_dims);
   }
 }
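A standalone scalar reference for the rotation the vector code must reproduce (plain C++, no Highway; illustrative). Both outputs consume the pre-rotation x0 and x1, which is why the vectorized path keeps the pair in vout0/vout1 temporaries before storing instead of overwriting the inputs one at a time.

#include <cmath>
#include <cstddef>

// Scalar RoPE: rotate each pair (x[d], x[d + D/2]) by
// theta = pos * inv_timescale[d]. Mirrors the loop the hunk above vectorizes.
void RopeScalar(float* x, size_t dim_qkv, const float* inv_timescale,
                int pos) {
  const size_t half = dim_qkv / 2;
  for (size_t d = 0; d < half; ++d) {
    const float theta = static_cast<float>(pos) * inv_timescale[d];
    const float c = std::cos(theta), s = std::sin(theta);
    const float x0 = x[d], x1 = x[d + half];  // keep originals: both outputs
    x[d] = x0 * c - x1 * s;                   // depend on x0 and x1
    x[d + half] = x0 * s + x1 * c;
  }
}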
@@ -521,6 +553,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(const XT* HWY_RESTRICT x,
   HWY_DASSERT(remaining1 < NF);
   if (HWY_UNLIKELY(remaining != 0)) {
     HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)];
+    // Ensure the second vector is zeroed even if remaining <= NF.
+    hn::Store(hn::Zero(df), df, buf_x + NF);
     DecompressAndZeroPad(df, packed_x, i, buf_x, remaining);
     const VF x0 = hn::Load(df, buf_x);
     const VF x1 = hn::Load(df, buf_x + NF);
@@ -586,42 +620,43 @@ static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x,
   }
 }
 
-static HWY_NOINLINE void MulBy(const float* HWY_RESTRICT other,
-                               float* HWY_RESTRICT x, const size_t size,
-                               const size_t max_pos) {
-  PROFILER_ZONE("ops.MulBy");
-  HWY_DASSERT(max_pos <= size);
-  namespace hn = hwy::HWY_NAMESPACE;
-  using D = hn::ScalableTag<float>;
-  using V = hn::Vec<D>;
-
-  hn::Transform1(D(), x, max_pos, other,
-                 [](const auto d, const V x, const V other)
-                     HWY_ATTR { return hn::Mul(x, other); });
-}
-
-static HWY_INLINE HWY_MAYBE_UNUSED void MulBy(const float* HWY_RESTRICT other,
-                                              float* HWY_RESTRICT x,
-                                              const size_t size) {
-  return MulBy(other, x, size, size);
-}
-
-static HWY_NOINLINE void MulByConst(const float c, float* HWY_RESTRICT x,
-                                    const size_t size, const size_t max_pos) {
+template <typename XT>
+HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(float c, XT* HWY_RESTRICT x,
+                                              size_t size) {
   PROFILER_ZONE("ops.MulByConst");
-  HWY_DASSERT(max_pos <= size);
   namespace hn = hwy::HWY_NAMESPACE;
-  using D = hn::ScalableTag<float>;
-  using V = hn::Vec<D>;
-  hn::Transform(D(), x, max_pos, [c](const auto d, const V x) HWY_ATTR {
-    return hn::Mul(x, hn::Set(d, c));
-  });
-}
+  const hn::ScalableTag<float> df;
+  const size_t NF = hn::Lanes(df);
+  using VF = hn::Vec<decltype(df)>;
 
-static HWY_INLINE HWY_MAYBE_UNUSED void MulByConst(const float c,
-                                                   float* HWY_RESTRICT x,
-                                                   const size_t size) {
-  MulByConst(c, x, size, size);
+  const VF v_c = hn::Set(df, c);
+  const auto packed_x = MakeSpan(x, size);
+
+  size_t i = 0;
+  if (size >= 2 * NF) {
+    for (; i <= size - 2 * NF; i += 2 * NF) {
+      VF x0, x1;
+      Decompress2(df, packed_x, i, x0, x1);
+      x0 = hn::Mul(x0, v_c);
+      x1 = hn::Mul(x1, v_c);
+      Compress2(df, x0, x1, packed_x, i);
+    }
+  }
+
+  const size_t remaining = size - i;
+  HWY_DASSERT(remaining < 2 * NF);
+  if (HWY_UNLIKELY(remaining != 0)) {
+    HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)];
+    // Ensure the second vector is zeroed even if remaining <= NF.
+    hn::Store(hn::Zero(df), df, buf_x + NF);
+    DecompressAndZeroPad(df, packed_x, i, buf_x, remaining);
+    VF x0 = hn::Load(df, buf_x);
+    VF x1 = hn::Load(df, buf_x + NF);
+    x0 = hn::Mul(x0, v_c);
+    x1 = hn::Mul(x1, v_c);
+    Compress2(df, x0, x1, MakeSpan(buf_x, 2 * NF), 0);
+    hwy::CopyBytes(buf_x, x + i, remaining * sizeof(XT));
+  }
 }
 
 template <typename XT, typename OT>
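The templated MulByConst follows the same decompress/operate/compress skeleton as MulByConstAndAdd below: full two-vector strides in the main loop, then one zero-padded tail whose result is written back only for the valid elements. A scalar model of that control flow; NF stands in for hn::Lanes(df) and this is an illustration, not the repo's implementation.

#include <algorithm>
#include <cstddef>
#include <vector>

// Bulk strides of 2 * NF elements, then a zero-padded tail buffer from
// which only `remaining` results are copied back (the hwy::CopyBytes step).
void MulByConstModel(float c, float* x, size_t size, size_t NF) {
  size_t i = 0;
  for (; i + 2 * NF <= size; i += 2 * NF) {  // main loop: two vectors/step
    for (size_t j = 0; j < 2 * NF; ++j) x[i + j] *= c;
  }
  const size_t remaining = size - i;  // < 2 * NF
  if (remaining != 0) {
    std::vector<float> buf(2 * NF, 0.0f);  // both "vectors" fully defined
    std::copy(x + i, x + size, buf.begin());
    for (size_t j = 0; j < 2 * NF; ++j) buf[j] *= c;
    std::copy(buf.begin(), buf.begin() + remaining, x + i);
  }
}

Processing the padded lanes and discarding them is cheaper than a per-lane mask, which is the same trade the vector code makes.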
@@ -656,6 +691,9 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(float c,
   if (HWY_UNLIKELY(remaining != 0)) {
     HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)];
     HWY_ALIGN float buf_out[2 * hn::MaxLanes(df)];
+    // Ensure the second vectors are zeroed even if remaining <= NF.
+    hn::Store(hn::Zero(df), df, buf_x + NF);
+    hn::Store(hn::Zero(df), df, buf_out + NF);
     DecompressAndZeroPad(df, packed_x, i, buf_x, remaining);
     DecompressAndZeroPad(df, packed_out, i, buf_out, remaining);
     const VF x0 = hn::Load(df, buf_x);
@@ -671,11 +709,9 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(float c,
 
 // See below for a specialized version for top-1 sampling.
 static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
-                                 const size_t mask_pos,
                                  float temperature = 1.0f) {
   PROFILER_ZONE("ops.Softmax");
   HWY_DASSERT(size != 0);
-  HWY_DASSERT(mask_pos <= size);
 
   namespace hn = hwy::HWY_NAMESPACE;
   using D = hn::ScalableTag<float>;
@@ -685,13 +721,13 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
   const V vmin = hn::Set(d, hwy::LowestValue<float>());
   V vmax = vmin;
   V* pmax = &vmax;  // workaround for SVE: cannot capture &vector directly
-  hn::Foreach(d, x, mask_pos, vmin,
-              [pmax](const auto d, const V value)
-                  HWY_ATTR { *pmax = hn::Max(*pmax, value); });
+  hn::Foreach(d, x, size, vmin, [pmax](const auto d, const V value) HWY_ATTR {
+    *pmax = hn::Max(*pmax, value);
+  });
   vmax = hn::MaxOfLanes(d, vmax);
 
   // Subtract max (avoid precision loss for large exponents) and exponentiate.
-  hn::Transform(d, x, mask_pos, [pmax](const auto d, const V value) HWY_ATTR {
+  hn::Transform(d, x, size, [pmax](const auto d, const V value) HWY_ATTR {
     if constexpr (HWY_TARGET & HWY_ALL_SVE) {
       // Temporary workaround for buggy SVE codegen: avoid inlined Exp().
       return hn::CallExp(d, hn::Sub(value, *pmax));
@@ -702,7 +738,7 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
 
   if (temperature != 1.0f) {
     const float temperature_inv = 1.0f / temperature;
-    hn::Transform(d, x, mask_pos,
+    hn::Transform(d, x, size,
                   [temperature_inv](const auto d, const V value) HWY_ATTR {
                     return hn::Mul(value, hn::Set(d, temperature_inv));
                   });
@@ -712,16 +748,10 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
   // not make a huge difference. It halves the standard deviation of the sum of
   // the normalized probabilities from 1E-7 to 5E-8, but actually also changes
   // the generated text after a few hundred tokens.
-  const float sum_exp = Sum(d, x, mask_pos);
+  const float sum_exp = Sum(d, x, size);
   // Double-precision reciprocal does not appear to affect the results.
   const float mul = 1.0f / sum_exp;
-  MulByConst(mul, x, size, mask_pos);
-}
-
-static HWY_INLINE HWY_MAYBE_UNUSED void Softmax(float* HWY_RESTRICT x,
-                                                const size_t size,
-                                                float temperature = 1.0f) {
-  Softmax(x, size, size, temperature);
+  MulByConst(mul, x, size);
 }
 
 // Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max /