mirror of https://github.com/google/gemma.cpp.git
[Gemma.cpp] Allows non-owned arguments for attention methods.
* Adds and uses a new `AttentionActivationPtrs` that holds non-owning `MatPtrs`. Acts as a view into `AttentionActivations`. * Updates `QBatch` to hold non-owning `MatPtr`s to the kv caches. * Enables the `MatPtrT` default constructor for simpler initializations. * Pulls out and passes `LayerWeightsPtrs::query_norm_scale` directly. While `LayerWeightsPtrs` already held non-owning `MatPtr`s, this change avoids the need to find and construct several empty weight tensors just to construct one `query_norm_scale` tensor. PiperOrigin-RevId: 824584177
This commit is contained in:
parent
86200ce224
commit
5a05857deb
|
|
@ -139,7 +139,6 @@ cc_test(
|
|||
":kv_cache",
|
||||
":mat",
|
||||
":matmul",
|
||||
":ops",
|
||||
":threading_context",
|
||||
":weights",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
|
|
|
|||
|
|
@ -31,26 +31,24 @@
|
|||
|
||||
namespace gcpp {
|
||||
|
||||
struct AttentionActivations {
|
||||
// Returns the scale value to use for the query in the attention computation.
|
||||
// Also called by ops_test.
|
||||
static inline float ChooseQueryScale(const ModelConfig& config) {
|
||||
const LayerConfig& layer_config = config.layer_configs[0];
|
||||
if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads)
|
||||
return 1.0f /
|
||||
sqrtf(static_cast<float>(config.model_dim / layer_config.heads));
|
||||
// QueryScaleType::SqrtKeySize
|
||||
return 1.0f / sqrtf(static_cast<float>(layer_config.qkv_dim));
|
||||
}
|
||||
// Returns the scale value to use for the query in the attention computation.
|
||||
// Also called by ops_test.
|
||||
static inline float ChooseQueryScale(const ModelConfig& config) {
|
||||
const LayerConfig& layer_config = config.layer_configs[0];
|
||||
if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads)
|
||||
return 1.0f /
|
||||
sqrtf(static_cast<float>(config.model_dim / layer_config.heads));
|
||||
// QueryScaleType::SqrtKeySize
|
||||
return 1.0f / sqrtf(static_cast<float>(layer_config.qkv_dim));
|
||||
}
|
||||
|
||||
struct AttentionActivations {
|
||||
AttentionActivations(
|
||||
const ModelConfig& config, const LayerConfig& layer_config,
|
||||
size_t batch_size, size_t seq_len, const Allocator& allocator,
|
||||
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
|
||||
: config(config),
|
||||
|
||||
// `vocab_size == 0` means it is for Vit part, VitAttention is still MHA
|
||||
// and does not use an external KV cache.
|
||||
: // `vocab_size == 0` means it is for Vit part, VitAttention is still
|
||||
// MHA and does not use an external KV cache.
|
||||
q(MatFactory("q", batch_size,
|
||||
config.vocab_size == 0
|
||||
? layer_config.heads * 3 * layer_config.qkv_dim
|
||||
|
|
@ -76,11 +74,7 @@ struct AttentionActivations {
|
|||
layer_config.post_qk == PostQKType::HalfRope)),
|
||||
inv_timescale_global(CreateInvTimescale(
|
||||
allocator, layer_config.qkv_dim,
|
||||
layer_config.post_qk == PostQKType::HalfRope, 1000000.0)),
|
||||
|
||||
div_seq_len(static_cast<uint32_t>(seq_len)),
|
||||
div_heads(static_cast<uint32_t>(layer_config.heads)),
|
||||
query_scale(ChooseQueryScale(config)) {
|
||||
layer_config.post_qk == PostQKType::HalfRope, 1000000.0)) {
|
||||
// Batch size can be 0 in experimental code so do not assert.
|
||||
if (batch_size == 0) {
|
||||
static std::atomic_flag warned = ATOMIC_FLAG_INIT;
|
||||
|
|
@ -108,9 +102,7 @@ struct AttentionActivations {
|
|||
att_sums.OverrideRows(batch_size);
|
||||
}
|
||||
|
||||
const ModelConfig& config;
|
||||
|
||||
MatStorageT<float> q; // query
|
||||
MatStorageT<float> q; // query
|
||||
MatStorageT<float> q_T; // Transposed to maximize attention speed.
|
||||
|
||||
MatStorageT<float> pre_att_rms_out;
|
||||
|
|
@ -122,9 +114,39 @@ struct AttentionActivations {
|
|||
// Rope
|
||||
MatStorageT<float> inv_timescale;
|
||||
MatStorageT<float> inv_timescale_global;
|
||||
};
|
||||
|
||||
// A non-owning view of AttentionActivations.
|
||||
struct AttentionActivationsPtrs {
|
||||
AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len)
|
||||
: config(config),
|
||||
div_seq_len(static_cast<uint32_t>(seq_len)),
|
||||
div_heads(static_cast<uint32_t>(config.layer_configs[0].heads)),
|
||||
query_scale(ChooseQueryScale(config)) {}
|
||||
|
||||
AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len,
|
||||
const AttentionActivations& activations)
|
||||
: AttentionActivationsPtrs(config, seq_len) {
|
||||
q = activations.q;
|
||||
q_T = activations.q_T;
|
||||
pre_att_rms_out = activations.pre_att_rms_out;
|
||||
att = activations.att;
|
||||
att_out = activations.att_out;
|
||||
att_sums = activations.att_sums;
|
||||
inv_timescale = activations.inv_timescale;
|
||||
inv_timescale_global = activations.inv_timescale_global;
|
||||
}
|
||||
|
||||
const ModelConfig& config;
|
||||
MatPtrT<float> q;
|
||||
MatPtrT<float> q_T;
|
||||
MatPtrT<float> pre_att_rms_out;
|
||||
MatPtrT<float> att;
|
||||
MatPtrT<float> att_out;
|
||||
MatPtrT<BF16> att_sums;
|
||||
MatPtrT<float> inv_timescale;
|
||||
MatPtrT<float> inv_timescale_global;
|
||||
hwy::Divisor div_seq_len;
|
||||
// Unfortunately, some models have had non-power-of-two heads.
|
||||
hwy::Divisor div_heads;
|
||||
float query_scale;
|
||||
};
|
||||
|
|
@ -150,8 +172,9 @@ struct Activations {
|
|||
ffw_out(
|
||||
MatFactory("ffw_out", batch_size, config.model_dim, ctx.allocator)),
|
||||
|
||||
attention(config, layer_config, batch_size, seq_len, ctx.allocator,
|
||||
row_ptrs) {
|
||||
attention_storage(config, layer_config, batch_size, seq_len,
|
||||
ctx.allocator, row_ptrs),
|
||||
attention(config, seq_len, attention_storage) {
|
||||
HWY_ASSERT(batch_size != 0);
|
||||
|
||||
// For MatMul outputs, precompute their row pointers.
|
||||
|
|
@ -179,12 +202,12 @@ struct Activations {
|
|||
C2.OverrideRows(batch_size);
|
||||
ffw_out.OverrideRows(batch_size);
|
||||
|
||||
attention.SetBatchSize(batch_size);
|
||||
attention_storage.SetBatchSize(batch_size);
|
||||
}
|
||||
|
||||
const LayerConfig& layer_config;
|
||||
|
||||
MatStorageT<float> x; // input
|
||||
MatStorageT<float> x; // input
|
||||
MatStorageT<BF16> x_bf; // output of final RMSNorm, input to EmbeddingMatmul
|
||||
MatStorageT<float> logits; // TODO: BF16 after Softmax supports that.
|
||||
MatStorageT<uint32_t> sampled; // batch_size x 3 (padded)
|
||||
|
|
@ -195,7 +218,8 @@ struct Activations {
|
|||
MatStorageT<BF16> C2;
|
||||
MatStorageT<float> ffw_out;
|
||||
|
||||
AttentionActivations attention;
|
||||
AttentionActivations attention_storage;
|
||||
AttentionActivationsPtrs attention;
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -73,12 +73,12 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
|
|||
}
|
||||
|
||||
void PositionalEncodingQK(float* qk, const size_t layer_idx,
|
||||
const LayerWeightsPtrs& layer,
|
||||
const AttentionActivations& activations,
|
||||
const AttentionActivationsPtrs& activations,
|
||||
ThreadingContext& ctx, const size_t worker,
|
||||
const size_t pos, const float mul) {
|
||||
const size_t qkv_dim = layer.layer_config.qkv_dim;
|
||||
const PostQKType& post_qk = layer.layer_config.post_qk;
|
||||
const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
|
||||
const size_t qkv_dim = layer_config.qkv_dim;
|
||||
const PostQKType& post_qk = layer_config.post_qk;
|
||||
// qk is either q or k, so qkv_dim is the length we operate on.
|
||||
const float* inv_timescale = activations.inv_timescale.PackedScale1();
|
||||
const bool is_global_layer = activations.config.IsGlobalLayer(layer_idx);
|
||||
|
|
@ -130,23 +130,23 @@ static HWY_INLINE void WeightedSumV(
|
|||
void SingleDotSoftmaxWeightedSum(
|
||||
const size_t pos, const size_t start_pos, const size_t last_pos,
|
||||
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
|
||||
const size_t layer_idx, const LayerWeightsPtrs& layer,
|
||||
const AttentionActivations& activations, float* HWY_RESTRICT att,
|
||||
const MatPtrT<float>& query_norm_scale, const size_t layer_idx,
|
||||
const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att,
|
||||
float* HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) {
|
||||
const float att_cap = activations.config.att_cap;
|
||||
const float query_scale = activations.query_scale;
|
||||
const size_t seq_len =
|
||||
static_cast<size_t>(activations.div_seq_len.GetDivisor());
|
||||
|
||||
const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
|
||||
// Apply rope and scaling to Q.
|
||||
if (layer.query_norm_scale.HasPtr()) {
|
||||
CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
|
||||
if (query_norm_scale.HasPtr()) {
|
||||
CallUpcasted(&query_norm_scale, [&](const auto* weights_t) {
|
||||
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q,
|
||||
layer.layer_config.qkv_dim, ctx, worker);
|
||||
layer_config.qkv_dim, ctx, worker);
|
||||
});
|
||||
}
|
||||
|
||||
PositionalEncodingQK(q, layer_idx, layer, activations, ctx, worker, pos,
|
||||
PositionalEncodingQK(q, layer_idx, activations, ctx, worker, pos,
|
||||
query_scale);
|
||||
|
||||
QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, ctx, worker);
|
||||
|
|
@ -169,13 +169,13 @@ size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx) {
|
|||
}
|
||||
|
||||
void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
|
||||
const LayerWeightsPtrs& layer,
|
||||
AttentionActivations& activations, QBatch& qbatch,
|
||||
ThreadingContext& ctx) {
|
||||
const MatPtrT<float>& query_norm_scale,
|
||||
AttentionActivationsPtrs& activations,
|
||||
QBatch& qbatch, ThreadingContext& ctx) {
|
||||
GCPP_ZONE(ctx, 0, Zones::kGenAttentionDotSoftmaxWeightedSumInclusive);
|
||||
|
||||
const hwy::Divisor div_qbatch(qbatch.Size());
|
||||
const LayerConfig& layer_config = layer.layer_config;
|
||||
const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
|
||||
const size_t qkv_dim = layer_config.qkv_dim;
|
||||
|
||||
// A "head group" in the context of GQA refers to a collection of query
|
||||
|
|
@ -223,8 +223,9 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
|
|||
MatPtrT<KV_t> v("v_view", Extents2D(seq_len, qkv_dim));
|
||||
v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());
|
||||
|
||||
SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v, layer_idx,
|
||||
layer, activations, att, att_out, ctx, worker);
|
||||
SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v,
|
||||
query_norm_scale, layer_idx, activations, att,
|
||||
att_out, ctx, worker);
|
||||
};
|
||||
|
||||
{
|
||||
|
|
@ -245,7 +246,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
|
|||
// Fills activations.q and writes to KV cache.
|
||||
static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
|
||||
const LayerWeightsPtrs& layer,
|
||||
AttentionActivations& activations,
|
||||
AttentionActivationsPtrs& activations,
|
||||
const QBatch& qbatch, const int flags,
|
||||
MatMulEnv& env) {
|
||||
GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(),
|
||||
|
|
@ -312,8 +313,8 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
|
|||
});
|
||||
}
|
||||
|
||||
PositionalEncodingQK(kv_f32, layer_idx, layer, activations, env.ctx,
|
||||
worker, pos, /*mul=*/1.0f);
|
||||
PositionalEncodingQK(kv_f32, layer_idx, activations, env.ctx, worker,
|
||||
pos, /*mul=*/1.0f);
|
||||
CompressPerThread tls;
|
||||
Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
|
||||
});
|
||||
|
|
@ -322,7 +323,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
|
|||
// Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and
|
||||
// head_dim (`qkv_dim`) into output (`layer_out`).
|
||||
static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
|
||||
AttentionActivations& activations,
|
||||
AttentionActivationsPtrs& activations,
|
||||
MatMulEnv& env) {
|
||||
GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttentionSumHeads);
|
||||
const LayerConfig& layer_config = layer.layer_config;
|
||||
|
|
@ -340,7 +341,7 @@ static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
|
|||
|
||||
void GemmaAttention(size_t num_tokens, const size_t layer_idx,
|
||||
const LayerWeightsPtrs& layer,
|
||||
AttentionActivations& activations, QBatch& qbatch,
|
||||
AttentionActivationsPtrs& activations, QBatch& qbatch,
|
||||
MatMulEnv& env, int flags) {
|
||||
GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttention);
|
||||
|
||||
|
|
@ -352,13 +353,14 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,
|
|||
|
||||
ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env);
|
||||
if (flags & kAttentionUseOld) {
|
||||
DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch,
|
||||
env.ctx);
|
||||
DotSoftmaxWeightedSum(num_tokens, layer_idx, layer.query_norm_scale,
|
||||
activations, qbatch, env.ctx);
|
||||
} else {
|
||||
// * 2 does not help on Turin.
|
||||
FlashAttention(num_tokens,
|
||||
/*target_parallelism=*/env.ctx.pools.MaxWorkers() * 1,
|
||||
layer_idx, layer, activations, qbatch, env.ctx);
|
||||
layer_idx, layer.query_norm_scale, activations, qbatch,
|
||||
env.ctx);
|
||||
}
|
||||
SumHeads(layer, activations, env);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -29,8 +29,7 @@ namespace gcpp {
|
|||
#define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \
|
||||
namespace NAMESPACE { \
|
||||
void PositionalEncodingQK(float* qk, size_t layer_idx, \
|
||||
const LayerWeightsPtrs& layer, \
|
||||
const AttentionActivations& activations, \
|
||||
const AttentionActivationsPtrs& activations, \
|
||||
ThreadingContext& ctx, size_t worker, size_t pos, \
|
||||
float mul); \
|
||||
\
|
||||
|
|
@ -39,18 +38,18 @@ namespace gcpp {
|
|||
void SingleDotSoftmaxWeightedSum( \
|
||||
const size_t pos, const size_t start_pos, const size_t last_pos, \
|
||||
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
|
||||
size_t layer_idx, const LayerWeightsPtrs& layer, \
|
||||
const AttentionActivations& activations, float* HWY_RESTRICT att, \
|
||||
const MatPtrT<float>& query_norm_scale, size_t layer_idx, \
|
||||
const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att, \
|
||||
float* HWY_RESTRICT att_out, ThreadingContext& ctx, size_t worker); \
|
||||
\
|
||||
void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \
|
||||
const LayerWeightsPtrs& layer, \
|
||||
AttentionActivations& activations, \
|
||||
const MatPtrT<float>& query_norm_scale, \
|
||||
AttentionActivationsPtrs& activations, \
|
||||
QBatch& qbatch, ThreadingContext& ctx); \
|
||||
\
|
||||
void GemmaAttention(size_t num_tokens, const size_t layer_idx, \
|
||||
const LayerWeightsPtrs& layer, \
|
||||
AttentionActivations& activations, QBatch& qbatch, \
|
||||
AttentionActivationsPtrs& activations, QBatch& qbatch, \
|
||||
MatMulEnv& env, int flags); \
|
||||
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
|
||||
} // namespace NAMESPACE
|
||||
|
|
|
|||
|
|
@ -30,7 +30,6 @@
|
|||
#include "gemma/activations.h"
|
||||
#include "gemma/configs.h" // kMaxQKVDim
|
||||
#include "gemma/gemma.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "util/threading.h"
|
||||
#include "hwy/profiler.h"
|
||||
|
||||
|
|
@ -91,32 +90,33 @@ static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
|
|||
|
||||
// Updates q in place for RMSNorm and positional encoding.
|
||||
void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
|
||||
MatPtrT<KV_t>& q, const size_t layer_idx,
|
||||
const LayerWeightsPtrs& layer,
|
||||
const AttentionActivations& activations,
|
||||
MatPtrT<float>& q,
|
||||
const MatPtrT<float>& query_norm_scale,
|
||||
const size_t layer_idx,
|
||||
const AttentionActivationsPtrs& activations,
|
||||
ThreadingContext& ctx) {
|
||||
const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
|
||||
const float query_scale = activations.query_scale;
|
||||
const hwy::Divisor div_qbatch(qbatch.Size());
|
||||
const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
|
||||
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionRmsNormAndPositionalEncoding);
|
||||
size_t qi = div_qbatch.Remainder(task);
|
||||
size_t batch_idx = div_qbatch.Divide(task);
|
||||
for (size_t h = 0; h < layer.layer_config.heads; ++h) {
|
||||
for (size_t h = 0; h < layer_config.heads; ++h) {
|
||||
const size_t tq_idx = qbatch.Size() * batch_idx + qi;
|
||||
// Find the token position in the query and calculate
|
||||
// the range of cache positions to attend to.
|
||||
const size_t pos = qbatch.Pos(qi) + batch_idx;
|
||||
float* HWY_RESTRICT q_row =
|
||||
q.Row(tq_idx) + h * layer.layer_config.qkv_dim;
|
||||
float* HWY_RESTRICT q_row = q.Row(tq_idx) + h * layer_config.qkv_dim;
|
||||
// Apply rope and scaling to Q.
|
||||
if (layer.query_norm_scale.HasPtr()) {
|
||||
CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
|
||||
if (query_norm_scale.HasPtr()) {
|
||||
CallUpcasted(&query_norm_scale, [&](const auto* weights_t) {
|
||||
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q_row,
|
||||
layer.layer_config.qkv_dim, ctx, worker);
|
||||
layer_config.qkv_dim, ctx, worker);
|
||||
});
|
||||
}
|
||||
PositionalEncodingQK(q_row, layer_idx, layer, activations, ctx, worker,
|
||||
pos, query_scale);
|
||||
PositionalEncodingQK(q_row, layer_idx, activations, ctx, worker, pos,
|
||||
query_scale);
|
||||
}
|
||||
};
|
||||
{
|
||||
|
|
@ -154,8 +154,7 @@ void HWY_INLINE SingleFlashAttentionStep(float x, float cap, float& old_max,
|
|||
void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
|
||||
const float* HWY_RESTRICT q, const MatPtrT<KV_t>& k,
|
||||
const MatPtrT<KV_t>& v, const size_t layer_idx,
|
||||
const LayerWeightsPtrs& layer,
|
||||
const AttentionActivations& activations,
|
||||
const AttentionActivationsPtrs& activations,
|
||||
float* HWY_RESTRICT att_out, ThreadingContext& ctx,
|
||||
const size_t worker) {
|
||||
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionSingleFlashAttention);
|
||||
|
|
@ -265,15 +264,17 @@ VF HWY_INLINE ElementwiseSumOf8(DF df, const VF& x0, const VF& x1, const VF& x2,
|
|||
// Sweeps a tile of NF Q rows by 8 K timesteps accumulators from start_pos to
|
||||
// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
|
||||
// max_last_pos].
|
||||
void TileFlashAttention(
|
||||
const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
|
||||
const StridedView<float>& qT, const MatPtrT<KV_t>& k,
|
||||
const size_t start_pos, const uint32_t* HWY_RESTRICT last_pos,
|
||||
const size_t min_last_pos, const size_t max_last_pos,
|
||||
const MatPtrT<KV_t>& v, const size_t layer_idx,
|
||||
const LayerWeightsPtrs& layer, const AttentionActivations& activations,
|
||||
MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets,
|
||||
ThreadingContext& ctx, const size_t worker) {
|
||||
void TileFlashAttention(const MatPtrT<float>& q,
|
||||
const uint32_t* HWY_RESTRICT q_offsets,
|
||||
const StridedView<float>& qT, const MatPtrT<KV_t>& k,
|
||||
const size_t start_pos,
|
||||
const uint32_t* HWY_RESTRICT last_pos,
|
||||
const size_t min_last_pos, const size_t max_last_pos,
|
||||
const MatPtrT<KV_t>& v, const size_t layer_idx,
|
||||
const AttentionActivationsPtrs& activations,
|
||||
MatPtrT<float>& att_out,
|
||||
const uint32_t* HWY_RESTRICT out_offsets,
|
||||
ThreadingContext& ctx, const size_t worker) {
|
||||
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention);
|
||||
constexpr int kHTileSize = kNFx8HTileSize;
|
||||
using DF = hn::ScalableTag<float>;
|
||||
|
|
@ -419,14 +420,16 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
|
|||
// Sweeps a tile of 4 Q rows by NF K timesteps accumulators from start_pos to
|
||||
// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
|
||||
// max_last_pos].
|
||||
void TileFlashAttention4(
|
||||
const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
|
||||
const MatPtrT<KV_t>& k, const size_t start_pos,
|
||||
const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos,
|
||||
const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
|
||||
const LayerWeightsPtrs& layer, const AttentionActivations& activations,
|
||||
MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets,
|
||||
ThreadingContext& ctx, const size_t worker) {
|
||||
void TileFlashAttention4(const MatPtrT<float>& q,
|
||||
const uint32_t* HWY_RESTRICT q_offsets,
|
||||
const MatPtrT<KV_t>& k, const size_t start_pos,
|
||||
const uint32_t* HWY_RESTRICT last_pos,
|
||||
const size_t min_last_pos, const size_t max_last_pos,
|
||||
const MatPtrT<KV_t>& v, const size_t layer_idx,
|
||||
const AttentionActivationsPtrs& activations,
|
||||
MatPtrT<float>& att_out,
|
||||
const uint32_t* HWY_RESTRICT out_offsets,
|
||||
ThreadingContext& ctx, const size_t worker) {
|
||||
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention4);
|
||||
using DF = hn::ScalableTag<float>;
|
||||
const DF df;
|
||||
|
|
@ -589,14 +592,15 @@ size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens,
|
|||
// grouped together so that mode 1 or 2 can be used, and choosing which of the
|
||||
// 3 modes to use for best efficiency.
|
||||
void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
|
||||
const size_t layer_idx, const LayerWeightsPtrs& layer,
|
||||
AttentionActivations& activations, QBatch& qbatch,
|
||||
const size_t layer_idx,
|
||||
const MatPtrT<float>& query_norm_scale,
|
||||
AttentionActivationsPtrs& activations, QBatch& qbatch,
|
||||
ThreadingContext& ctx) {
|
||||
GCPP_ZONE(ctx, 0, Zones::kFlashAttentionInclusive);
|
||||
RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q, layer_idx,
|
||||
layer, activations, ctx);
|
||||
RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q,
|
||||
query_norm_scale, layer_idx, activations, ctx);
|
||||
const hwy::Divisor div_qbatch(qbatch.Size());
|
||||
const LayerConfig& layer_config = layer.layer_config;
|
||||
const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
|
||||
const size_t qkv_dim = layer_config.qkv_dim;
|
||||
|
||||
// A "head group" in the context of GQA refers to a collection of query
|
||||
|
|
@ -732,12 +736,12 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
|
|||
// is used above to catch all cases where qT will be used.
|
||||
TileFlashAttention(activations.q, q_offsets, qT, k,
|
||||
start_positions[offset], last_pos, min_last_pos,
|
||||
max_last_pos, v, layer_idx, layer, activations,
|
||||
max_last_pos, v, layer_idx, activations,
|
||||
activations.att_out, out_offsets, ctx, worker);
|
||||
} else if (kVTileSize == 4) {
|
||||
TileFlashAttention4(activations.q, q_offsets, k,
|
||||
start_positions[offset], last_pos, min_last_pos,
|
||||
max_last_pos, v, layer_idx, layer, activations,
|
||||
max_last_pos, v, layer_idx, activations,
|
||||
activations.att_out, out_offsets, ctx, worker);
|
||||
} else {
|
||||
HWY_UNREACHABLE;
|
||||
|
|
@ -746,7 +750,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
|
|||
} else {
|
||||
SingleFlashAttention(start_positions[offset], last_pos[offset],
|
||||
activations.q.Row(0) + q_offsets[offset], k, v,
|
||||
layer_idx, layer, activations,
|
||||
layer_idx, activations,
|
||||
activations.att_out.Row(0) + out_offsets[offset],
|
||||
ctx, worker);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -28,17 +28,16 @@ namespace gcpp {
|
|||
// Passed to HWY_VISIT_TARGETS; declares for one target.
|
||||
#define GEMMA_DECL_FLASH_ATTENTION(TARGET, NAMESPACE) \
|
||||
namespace NAMESPACE { \
|
||||
void RMSNormAndPositionalEncoding(size_t num_tokens, const QBatch& qbatch, \
|
||||
MatPtrT<KV_t>& q, size_t layer_idx, \
|
||||
const LayerWeightsPtrs& layer, \
|
||||
const AttentionActivations& activations, \
|
||||
ThreadingContext& ctx); \
|
||||
void RMSNormAndPositionalEncoding( \
|
||||
size_t num_tokens, const QBatch& qbatch, MatPtrT<float>& q, \
|
||||
const MatPtrT<float>& query_norm_scale, size_t layer_idx, \
|
||||
const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \
|
||||
\
|
||||
void SingleFlashAttention(size_t start_pos, size_t last_pos, \
|
||||
const float* HWY_RESTRICT q, \
|
||||
const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
|
||||
size_t layer_idx, const LayerWeightsPtrs& layer, \
|
||||
const AttentionActivations& activations, \
|
||||
size_t layer_idx, \
|
||||
const AttentionActivationsPtrs& activations, \
|
||||
float* HWY_RESTRICT att_out, \
|
||||
ThreadingContext& ctx, size_t worker); \
|
||||
\
|
||||
|
|
@ -46,8 +45,9 @@ namespace gcpp {
|
|||
size_t total_tasks, size_t target_parallelism); \
|
||||
\
|
||||
void FlashAttention(size_t num_tokens, size_t target_parallelism, \
|
||||
size_t layer_idx, const LayerWeightsPtrs& layer, \
|
||||
AttentionActivations& activations, QBatch& qbatch, \
|
||||
size_t layer_idx, \
|
||||
const MatPtrT<float>& query_norm_scale, \
|
||||
AttentionActivationsPtrs& activations, QBatch& qbatch, \
|
||||
ThreadingContext& ctx); \
|
||||
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
|
||||
} // namespace NAMESPACE
|
||||
|
|
|
|||
|
|
@ -122,8 +122,9 @@ void TestFlashAttention(size_t target_parallelism) {
|
|||
QBatch qbatch(/*start=*/0, /*max_size=*/kOuter, all_queries);
|
||||
const size_t batch_size = kOuter;
|
||||
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
|
||||
AttentionActivations attention(config, layer_config, batch_size, kOuter,
|
||||
ctx.allocator, row_ptrs);
|
||||
AttentionActivations attention_storage(config, layer_config, batch_size,
|
||||
kOuter, ctx.allocator, row_ptrs);
|
||||
AttentionActivationsPtrs attention(config, kOuter, attention_storage);
|
||||
const size_t qkv_dim = layer_config.qkv_dim;
|
||||
ASSERT_EQ(qkv_dim, kInner);
|
||||
const hwy::Divisor div_qbatch(qbatch.Size());
|
||||
|
|
@ -145,7 +146,8 @@ void TestFlashAttention(size_t target_parallelism) {
|
|||
SetMat(h + layer_config.heads * 2, v);
|
||||
}
|
||||
SetMat(1, attention.q);
|
||||
DotSoftmaxWeightedSum(tokens.size(), 0, layers, attention, qbatch, ctx);
|
||||
DotSoftmaxWeightedSum(tokens.size(), 0, layers.query_norm_scale, attention,
|
||||
qbatch, ctx);
|
||||
// Copy the output to saved_att to allow for comparison.
|
||||
auto saved_att = MakeCopyOfMat(attention.att_out, ctx.allocator);
|
||||
SetMat(1, attention.q);
|
||||
|
|
@ -158,8 +160,8 @@ void TestFlashAttention(size_t target_parallelism) {
|
|||
total_tasks, target_parallelism);
|
||||
printf("FlashAttention: target_parallelism=%zu, kNF=%zu, kVTileSize=%zu\n",
|
||||
target_parallelism, kNF, kVTileSize);
|
||||
FlashAttention(tokens.size(), target_parallelism, 0, layers, attention,
|
||||
qbatch, ctx);
|
||||
FlashAttention(tokens.size(), target_parallelism, 0, layers.query_norm_scale,
|
||||
attention, qbatch, ctx);
|
||||
AssertClose(attention.att_out, *saved_att);
|
||||
ctx.profiler.PrintResults();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ struct PerQuery {
|
|||
// attention in Paligemma.
|
||||
size_t prefix_end;
|
||||
|
||||
KVCache& kv_cache;
|
||||
KVCachePtr kv_cache;
|
||||
|
||||
// Previous token generated for this query, or the last prompt token. Will be
|
||||
// fed into the next Transformer() call.
|
||||
|
|
@ -64,7 +64,7 @@ struct AllQueries {
|
|||
|
||||
// For `GenerateSingleT`: same prompt/pos, replicated for each KV cache.
|
||||
AllQueries(const PromptTokens& prompt, size_t pos, size_t prefix_end,
|
||||
const hwy::Span<KVCache>& kv_caches) {
|
||||
const hwy::Span<KVCachePtr>& kv_caches) {
|
||||
per_query_.reserve(kv_caches.size());
|
||||
for (size_t i = 0; i < kv_caches.size(); ++i) {
|
||||
HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen());
|
||||
|
|
@ -78,11 +78,16 @@ struct AllQueries {
|
|||
}
|
||||
}
|
||||
|
||||
AllQueries(const PromptTokens& prompt, size_t pos, size_t prefix_end,
|
||||
const hwy::Span<KVCache>& kv_caches)
|
||||
: AllQueries(prompt, pos, prefix_end,
|
||||
hwy::Span<KVCachePtr>(ToKVCachePtrs(kv_caches))) {}
|
||||
|
||||
// Batch of queries with initial position set to zero. Causal attention
|
||||
// is requested via empty or all-zero `prefix_end`.
|
||||
AllQueries(
|
||||
const hwy::Span<const PromptTokens>& prompts,
|
||||
const hwy::Span<KVCache>& kv_caches,
|
||||
const hwy::Span<KVCachePtr>& kv_caches,
|
||||
const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>()) {
|
||||
HWY_ASSERT(prompts.size() == kv_caches.size());
|
||||
HWY_ASSERT(prompts.size() == prefix_end.size() || prefix_end.size() == 0);
|
||||
|
|
@ -99,6 +104,13 @@ struct AllQueries {
|
|||
}
|
||||
}
|
||||
|
||||
AllQueries(
|
||||
const hwy::Span<const PromptTokens>& prompts,
|
||||
const hwy::Span<KVCache>& kv_caches,
|
||||
const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>())
|
||||
: AllQueries(prompts, hwy::Span<KVCachePtr>(ToKVCachePtrs(kv_caches)),
|
||||
prefix_end) {}
|
||||
|
||||
void Reserve(size_t size) { per_query_.reserve(size); }
|
||||
void Append(const PerQuery& query) { per_query_.push_back(query); }
|
||||
|
||||
|
|
@ -156,7 +168,7 @@ class QBatch {
|
|||
size_t PrefixEnd(size_t qi) const {
|
||||
return queries_[QueryIdx(qi)].prefix_end;
|
||||
}
|
||||
KVCache& KV(size_t qi) const { return queries_[QueryIdx(qi)].kv_cache; }
|
||||
KVCachePtr& KV(size_t qi) const { return queries_[QueryIdx(qi)].kv_cache; }
|
||||
int& PrevToken(size_t qi) { return queries_[QueryIdx(qi)].prev_token; }
|
||||
|
||||
private:
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@
|
|||
#include "gemma/kv_cache.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <vector>
|
||||
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/gemma_args.h"
|
||||
|
|
@ -54,4 +55,13 @@ KVCache KVCache::Copy() {
|
|||
return copy;
|
||||
}
|
||||
|
||||
std::vector<KVCachePtr> ToKVCachePtrs(const hwy::Span<KVCache>& kv_caches) {
|
||||
std::vector<KVCachePtr> ptrs;
|
||||
ptrs.reserve(kv_caches.size());
|
||||
for (size_t i = 0; i < kv_caches.size(); ++i) {
|
||||
ptrs.push_back(KVCachePtr{.kv_cache = kv_caches[i].kv_cache});
|
||||
}
|
||||
return ptrs;
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@
|
|||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <vector>
|
||||
|
||||
#include "gemma/configs.h" // ModelConfig
|
||||
#include "gemma/gemma_args.h" // InferenceArgs
|
||||
|
|
@ -46,6 +47,15 @@ struct KVCache {
|
|||
KVCache(const Extents2D& kv_extents, const Allocator& allocator);
|
||||
};
|
||||
|
||||
// A non-owning view of a KVCache.
|
||||
struct KVCachePtr {
|
||||
size_t SeqLen() const { return kv_cache.Rows(); }
|
||||
MatPtrT<KV_t> kv_cache;
|
||||
};
|
||||
|
||||
// Convenience function to create views into KVCaches.
|
||||
std::vector<KVCachePtr> ToKVCachePtrs(const hwy::Span<KVCache>& kv_caches);
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_
|
||||
|
|
|
|||
|
|
@ -454,7 +454,7 @@ void TestRopeAndMulBy() {
|
|||
x.Row(0)[i] = random_float();
|
||||
}
|
||||
|
||||
const float qmul = AttentionActivations::ChooseQueryScale(config);
|
||||
const float qmul = ChooseQueryScale(config);
|
||||
constexpr float kmul = 1.0f;
|
||||
|
||||
MatStorageT<float> qexpected("qexpected", dim_qkv, ctx.allocator);
|
||||
|
|
|
|||
|
|
@ -284,6 +284,9 @@ class MatPtrT : public MatPtr {
|
|||
public:
|
||||
using T = MatT;
|
||||
|
||||
// Default constructor for use with uninitialized views.
|
||||
MatPtrT() = default;
|
||||
|
||||
// Called by `MatStorageT`.
|
||||
MatPtrT(const char* name, Extents2D extents)
|
||||
: MatPtr(name, TypeEnum<MatT>(), extents) {}
|
||||
|
|
|
|||
Loading…
Reference in New Issue