[Gemma.cpp] Allows non-owning arguments for attention methods.

* Adds and uses a new `AttentionActivationsPtrs` struct that holds non-owning `MatPtr`s and acts as a view into `AttentionActivations` (see the sketch after this list).
* Updates `QBatch` to hold non-owning `MatPtr`s to the KV caches, via the new `KVCachePtr` view.
* Enables the `MatPtrT` default constructor for simpler initializations.
* Pulls out and passes `LayerWeightsPtrs::query_norm_scale` directly. While `LayerWeightsPtrs` already held non-owning `MatPtr`s, this change avoids the need to find and construct several empty weight tensors just to construct one `query_norm_scale` tensor.
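
For illustration, a minimal, self-contained sketch of the owning-storage / non-owning-view split this change introduces. The types below (`TensorStorage`, `TensorPtr`, `AttentionStorage`, `AttentionView`) are simplified stand-ins invented for this example, not the actual gemma.cpp classes (`MatStorageT`, `MatPtrT`, `AttentionActivations`, `AttentionActivationsPtrs`); they only demonstrate the pattern of constructing one owning object and passing cheap, default-constructible views into the compute functions.

```cpp
#include <cstddef>
#include <vector>

// Owning buffer, in the spirit of MatStorageT<float>.
struct TensorStorage {
  explicit TensorStorage(std::size_t n) : data(n) {}
  std::vector<float> data;
};

// Non-owning, default-constructible view, in the spirit of MatPtrT<float>.
struct TensorPtr {
  TensorPtr() = default;
  TensorPtr(TensorStorage& s) : ptr(s.data.data()), size(s.data.size()) {}
  float* ptr = nullptr;
  std::size_t size = 0;
};

// Owns all attention buffers (stand-in for AttentionActivations).
struct AttentionStorage {
  AttentionStorage(std::size_t batch, std::size_t dim)
      : q(batch * dim), att_out(batch * dim) {}
  TensorStorage q;
  TensorStorage att_out;
};

// Cheap view over AttentionStorage (stand-in for AttentionActivationsPtrs).
struct AttentionView {
  AttentionView() = default;
  explicit AttentionView(AttentionStorage& s) : q(s.q), att_out(s.att_out) {}
  TensorPtr q;
  TensorPtr att_out;
};

// Attention-style kernels can take the view; they neither own memory nor need
// the full weight struct that owns query_norm_scale.
float SumQ(const AttentionView& a) {
  float sum = 0.0f;
  for (std::size_t i = 0; i < a.q.size; ++i) sum += a.q.ptr[i];
  return sum;
}

int main() {
  AttentionStorage storage(/*batch=*/2, /*dim=*/4);  // owns the buffers
  AttentionView view(storage);  // non-owning view handed to kernels
  return SumQ(view) == 0.0f ? 0 : 1;  // vectors are zero-initialized
}
```

The real code follows the same shape: `Activations` owns an `AttentionActivations attention_storage` and exposes an `AttentionActivationsPtrs attention` view, which the attention kernels take by reference.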

PiperOrigin-RevId: 824584177
commit 5a05857deb (parent 86200ce224)
Author: Biruk Mammo, 2025-10-27 10:42:46 -07:00 (committed by Copybara-Service)
12 changed files with 185 additions and 120 deletions

@@ -139,7 +139,6 @@ cc_test(
         ":kv_cache",
         ":mat",
         ":matmul",
-        ":ops",
         ":threading_context",
         ":weights",
         "@googletest//:gtest_main",  # buildcleaner: keep

@@ -31,26 +31,24 @@
 namespace gcpp {
 
-struct AttentionActivations {
-  // Returns the scale value to use for the query in the attention computation.
-  // Also called by ops_test.
-  static inline float ChooseQueryScale(const ModelConfig& config) {
-    const LayerConfig& layer_config = config.layer_configs[0];
-    if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads)
-      return 1.0f /
-             sqrtf(static_cast<float>(config.model_dim / layer_config.heads));
-    // QueryScaleType::SqrtKeySize
-    return 1.0f / sqrtf(static_cast<float>(layer_config.qkv_dim));
-  }
+// Returns the scale value to use for the query in the attention computation.
+// Also called by ops_test.
+static inline float ChooseQueryScale(const ModelConfig& config) {
+  const LayerConfig& layer_config = config.layer_configs[0];
+  if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads)
+    return 1.0f /
+           sqrtf(static_cast<float>(config.model_dim / layer_config.heads));
+  // QueryScaleType::SqrtKeySize
+  return 1.0f / sqrtf(static_cast<float>(layer_config.qkv_dim));
+}
 
+struct AttentionActivations {
   AttentionActivations(
       const ModelConfig& config, const LayerConfig& layer_config,
       size_t batch_size, size_t seq_len, const Allocator& allocator,
       std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
-      : config(config),
-
-        // `vocab_size == 0` means it is for Vit part, VitAttention is still MHA
-        // and does not use an external KV cache.
+      : // `vocab_size == 0` means it is for Vit part, VitAttention is still
+        // MHA and does not use an external KV cache.
         q(MatFactory("q", batch_size,
                      config.vocab_size == 0
                          ? layer_config.heads * 3 * layer_config.qkv_dim
@@ -76,11 +74,7 @@ struct AttentionActivations {
                          layer_config.post_qk == PostQKType::HalfRope)),
         inv_timescale_global(CreateInvTimescale(
             allocator, layer_config.qkv_dim,
-            layer_config.post_qk == PostQKType::HalfRope, 1000000.0)),
-        div_seq_len(static_cast<uint32_t>(seq_len)),
-        div_heads(static_cast<uint32_t>(layer_config.heads)),
-        query_scale(ChooseQueryScale(config)) {
+            layer_config.post_qk == PostQKType::HalfRope, 1000000.0)) {
     // Batch size can be 0 in experimental code so do not assert.
     if (batch_size == 0) {
       static std::atomic_flag warned = ATOMIC_FLAG_INIT;
@@ -108,8 +102,6 @@ struct AttentionActivations {
     att_sums.OverrideRows(batch_size);
   }
 
-  const ModelConfig& config;
-
   MatStorageT<float> q;    // query
   MatStorageT<float> q_T;  // Transposed to maximize attention speed.
@@ -122,9 +114,39 @@ struct AttentionActivations {
   // Rope
   MatStorageT<float> inv_timescale;
   MatStorageT<float> inv_timescale_global;
+};
+
+// A non-owning view of AttentionActivations.
+struct AttentionActivationsPtrs {
+  AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len)
+      : config(config),
+        div_seq_len(static_cast<uint32_t>(seq_len)),
+        div_heads(static_cast<uint32_t>(config.layer_configs[0].heads)),
+        query_scale(ChooseQueryScale(config)) {}
+
+  AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len,
+                           const AttentionActivations& activations)
+      : AttentionActivationsPtrs(config, seq_len) {
+    q = activations.q;
+    q_T = activations.q_T;
+    pre_att_rms_out = activations.pre_att_rms_out;
+    att = activations.att;
+    att_out = activations.att_out;
+    att_sums = activations.att_sums;
+    inv_timescale = activations.inv_timescale;
+    inv_timescale_global = activations.inv_timescale_global;
+  }
+
+  const ModelConfig& config;
+  MatPtrT<float> q;
+  MatPtrT<float> q_T;
+  MatPtrT<float> pre_att_rms_out;
+  MatPtrT<float> att;
+  MatPtrT<float> att_out;
+  MatPtrT<BF16> att_sums;
+  MatPtrT<float> inv_timescale;
+  MatPtrT<float> inv_timescale_global;
 
   hwy::Divisor div_seq_len;
+  // Unfortunately, some models have had non-power-of-two heads.
   hwy::Divisor div_heads;
   float query_scale;
 };
@@ -150,8 +172,9 @@ struct Activations {
         ffw_out(
             MatFactory("ffw_out", batch_size, config.model_dim, ctx.allocator)),
-        attention(config, layer_config, batch_size, seq_len, ctx.allocator,
-                  row_ptrs) {
+        attention_storage(config, layer_config, batch_size, seq_len,
+                          ctx.allocator, row_ptrs),
+        attention(config, seq_len, attention_storage) {
     HWY_ASSERT(batch_size != 0);
 
     // For MatMul outputs, precompute their row pointers.
@@ -179,7 +202,7 @@ struct Activations {
     C2.OverrideRows(batch_size);
     ffw_out.OverrideRows(batch_size);
 
-    attention.SetBatchSize(batch_size);
+    attention_storage.SetBatchSize(batch_size);
   }
 
   const LayerConfig& layer_config;
@@ -195,7 +218,8 @@ struct Activations {
   MatStorageT<BF16> C2;
   MatStorageT<float> ffw_out;
 
-  AttentionActivations attention;
+  AttentionActivations attention_storage;
+  AttentionActivationsPtrs attention;
 };
 
 }  // namespace gcpp

@@ -73,12 +73,12 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
 }
 
 void PositionalEncodingQK(float* qk, const size_t layer_idx,
-                          const LayerWeightsPtrs& layer,
-                          const AttentionActivations& activations,
+                          const AttentionActivationsPtrs& activations,
                           ThreadingContext& ctx, const size_t worker,
                           const size_t pos, const float mul) {
-  const size_t qkv_dim = layer.layer_config.qkv_dim;
-  const PostQKType& post_qk = layer.layer_config.post_qk;
+  const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
+  const size_t qkv_dim = layer_config.qkv_dim;
+  const PostQKType& post_qk = layer_config.post_qk;
   // qk is either q or k, so qkv_dim is the length we operate on.
   const float* inv_timescale = activations.inv_timescale.PackedScale1();
   const bool is_global_layer = activations.config.IsGlobalLayer(layer_idx);
@@ -130,23 +130,23 @@ static HWY_INLINE void WeightedSumV(
 void SingleDotSoftmaxWeightedSum(
     const size_t pos, const size_t start_pos, const size_t last_pos,
     float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
-    const size_t layer_idx, const LayerWeightsPtrs& layer,
-    const AttentionActivations& activations, float* HWY_RESTRICT att,
+    const MatPtrT<float>& query_norm_scale, const size_t layer_idx,
+    const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att,
     float* HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) {
   const float att_cap = activations.config.att_cap;
   const float query_scale = activations.query_scale;
   const size_t seq_len =
       static_cast<size_t>(activations.div_seq_len.GetDivisor());
+  const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
 
   // Apply rope and scaling to Q.
-  if (layer.query_norm_scale.HasPtr()) {
-    CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
+  if (query_norm_scale.HasPtr()) {
+    CallUpcasted(&query_norm_scale, [&](const auto* weights_t) {
       RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q,
-                     layer.layer_config.qkv_dim, ctx, worker);
+                     layer_config.qkv_dim, ctx, worker);
     });
   }
 
-  PositionalEncodingQK(q, layer_idx, layer, activations, ctx, worker, pos,
+  PositionalEncodingQK(q, layer_idx, activations, ctx, worker, pos,
                        query_scale);
 
   QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, ctx, worker);
@@ -169,13 +169,13 @@ size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx) {
 }
 
 void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
-                           const LayerWeightsPtrs& layer,
-                           AttentionActivations& activations, QBatch& qbatch,
-                           ThreadingContext& ctx) {
+                           const MatPtrT<float>& query_norm_scale,
+                           AttentionActivationsPtrs& activations,
+                           QBatch& qbatch, ThreadingContext& ctx) {
   GCPP_ZONE(ctx, 0, Zones::kGenAttentionDotSoftmaxWeightedSumInclusive);
 
   const hwy::Divisor div_qbatch(qbatch.Size());
-  const LayerConfig& layer_config = layer.layer_config;
+  const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
   const size_t qkv_dim = layer_config.qkv_dim;
 
   // A "head group" in the context of GQA refers to a collection of query
@@ -223,8 +223,9 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
     MatPtrT<KV_t> v("v_view", Extents2D(seq_len, qkv_dim));
     v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());
 
-    SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v, layer_idx,
-                                layer, activations, att, att_out, ctx, worker);
+    SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v,
+                                query_norm_scale, layer_idx, activations, att,
+                                att_out, ctx, worker);
   };
 
   {
@@ -245,7 +246,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
 // Fills activations.q and writes to KV cache.
 static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
                                   const LayerWeightsPtrs& layer,
-                                  AttentionActivations& activations,
+                                  AttentionActivationsPtrs& activations,
                                   const QBatch& qbatch, const int flags,
                                   MatMulEnv& env) {
   GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(),
@@ -312,8 +313,8 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
         });
       }
 
-      PositionalEncodingQK(kv_f32, layer_idx, layer, activations, env.ctx,
-                           worker, pos, /*mul=*/1.0f);
+      PositionalEncodingQK(kv_f32, layer_idx, activations, env.ctx, worker,
+                           pos, /*mul=*/1.0f);
       CompressPerThread tls;
       Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
     });
@@ -322,7 +323,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
 // Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and
 // head_dim (`qkv_dim`) into output (`layer_out`).
 static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
-                                AttentionActivations& activations,
+                                AttentionActivationsPtrs& activations,
                                 MatMulEnv& env) {
   GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttentionSumHeads);
   const LayerConfig& layer_config = layer.layer_config;
@@ -340,7 +341,7 @@ static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
 
 void GemmaAttention(size_t num_tokens, const size_t layer_idx,
                     const LayerWeightsPtrs& layer,
-                    AttentionActivations& activations, QBatch& qbatch,
+                    AttentionActivationsPtrs& activations, QBatch& qbatch,
                     MatMulEnv& env, int flags) {
   GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttention);
@@ -352,13 +353,14 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,
   ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env);
   if (flags & kAttentionUseOld) {
-    DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch,
-                          env.ctx);
+    DotSoftmaxWeightedSum(num_tokens, layer_idx, layer.query_norm_scale,
+                          activations, qbatch, env.ctx);
   } else {
     // * 2 does not help on Turin.
    FlashAttention(num_tokens,
                   /*target_parallelism=*/env.ctx.pools.MaxWorkers() * 1,
-                   layer_idx, layer, activations, qbatch, env.ctx);
+                   layer_idx, layer.query_norm_scale, activations, qbatch,
+                   env.ctx);
   }
   SumHeads(layer, activations, env);
 }

@@ -29,8 +29,7 @@ namespace gcpp {
 #define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \
   namespace NAMESPACE { \
   void PositionalEncodingQK(float* qk, size_t layer_idx, \
-                            const LayerWeightsPtrs& layer, \
-                            const AttentionActivations& activations, \
+                            const AttentionActivationsPtrs& activations, \
                             ThreadingContext& ctx, size_t worker, size_t pos, \
                             float mul); \
 \
@@ -39,18 +38,18 @@ namespace gcpp {
   void SingleDotSoftmaxWeightedSum( \
       const size_t pos, const size_t start_pos, const size_t last_pos, \
       float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
-      size_t layer_idx, const LayerWeightsPtrs& layer, \
-      const AttentionActivations& activations, float* HWY_RESTRICT att, \
+      const MatPtrT<float>& query_norm_scale, size_t layer_idx, \
+      const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att, \
      float* HWY_RESTRICT att_out, ThreadingContext& ctx, size_t worker); \
 \
   void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \
-                             const LayerWeightsPtrs& layer, \
-                             AttentionActivations& activations, \
+                             const MatPtrT<float>& query_norm_scale, \
+                             AttentionActivationsPtrs& activations, \
                              QBatch& qbatch, ThreadingContext& ctx); \
 \
   void GemmaAttention(size_t num_tokens, const size_t layer_idx, \
                       const LayerWeightsPtrs& layer, \
-                      AttentionActivations& activations, QBatch& qbatch, \
+                      AttentionActivationsPtrs& activations, QBatch& qbatch, \
                       MatMulEnv& env, int flags); \
   /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
   }  // namespace NAMESPACE

@@ -30,7 +30,6 @@
 #include "gemma/activations.h"
 #include "gemma/configs.h"  // kMaxQKVDim
 #include "gemma/gemma.h"
-#include "gemma/weights.h"
 #include "util/threading.h"
 #include "hwy/profiler.h"
@@ -91,32 +90,33 @@ static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
 // Updates q in place for RMSNorm and positional encoding.
 void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
-                                  MatPtrT<KV_t>& q, const size_t layer_idx,
-                                  const LayerWeightsPtrs& layer,
-                                  const AttentionActivations& activations,
+                                  MatPtrT<float>& q,
+                                  const MatPtrT<float>& query_norm_scale,
+                                  const size_t layer_idx,
+                                  const AttentionActivationsPtrs& activations,
                                   ThreadingContext& ctx) {
+  const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
   const float query_scale = activations.query_scale;
   const hwy::Divisor div_qbatch(qbatch.Size());
   const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
     GCPP_ZONE(ctx, worker, Zones::kFlashAttentionRmsNormAndPositionalEncoding);
     size_t qi = div_qbatch.Remainder(task);
     size_t batch_idx = div_qbatch.Divide(task);
-    for (size_t h = 0; h < layer.layer_config.heads; ++h) {
+    for (size_t h = 0; h < layer_config.heads; ++h) {
       const size_t tq_idx = qbatch.Size() * batch_idx + qi;
       // Find the token position in the query and calculate
       // the range of cache positions to attend to.
       const size_t pos = qbatch.Pos(qi) + batch_idx;
-      float* HWY_RESTRICT q_row =
-          q.Row(tq_idx) + h * layer.layer_config.qkv_dim;
+      float* HWY_RESTRICT q_row = q.Row(tq_idx) + h * layer_config.qkv_dim;
       // Apply rope and scaling to Q.
-      if (layer.query_norm_scale.HasPtr()) {
-        CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
+      if (query_norm_scale.HasPtr()) {
+        CallUpcasted(&query_norm_scale, [&](const auto* weights_t) {
          RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q_row,
-                         layer.layer_config.qkv_dim, ctx, worker);
+                         layer_config.qkv_dim, ctx, worker);
        });
      }
-      PositionalEncodingQK(q_row, layer_idx, layer, activations, ctx, worker,
-                           pos, query_scale);
+      PositionalEncodingQK(q_row, layer_idx, activations, ctx, worker, pos,
+                           query_scale);
    }
  };
  {
@@ -154,8 +154,7 @@ void HWY_INLINE SingleFlashAttentionStep(float x, float cap, float& old_max,
 void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
                           const float* HWY_RESTRICT q, const MatPtrT<KV_t>& k,
                           const MatPtrT<KV_t>& v, const size_t layer_idx,
-                          const LayerWeightsPtrs& layer,
-                          const AttentionActivations& activations,
+                          const AttentionActivationsPtrs& activations,
                           float* HWY_RESTRICT att_out, ThreadingContext& ctx,
                           const size_t worker) {
   GCPP_ZONE(ctx, worker, Zones::kFlashAttentionSingleFlashAttention);
@@ -265,14 +264,16 @@ VF HWY_INLINE ElementwiseSumOf8(DF df, const VF& x0, const VF& x1, const VF& x2,
 // Sweeps a tile of NF Q rows by 8 K timesteps accumulators from start_pos to
 // min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
 // max_last_pos].
-void TileFlashAttention(
-    const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
-    const StridedView<float>& qT, const MatPtrT<KV_t>& k,
-    const size_t start_pos, const uint32_t* HWY_RESTRICT last_pos,
-    const size_t min_last_pos, const size_t max_last_pos,
-    const MatPtrT<KV_t>& v, const size_t layer_idx,
-    const LayerWeightsPtrs& layer, const AttentionActivations& activations,
-    MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets,
-    ThreadingContext& ctx, const size_t worker) {
+void TileFlashAttention(const MatPtrT<float>& q,
+                        const uint32_t* HWY_RESTRICT q_offsets,
+                        const StridedView<float>& qT, const MatPtrT<KV_t>& k,
+                        const size_t start_pos,
+                        const uint32_t* HWY_RESTRICT last_pos,
+                        const size_t min_last_pos, const size_t max_last_pos,
+                        const MatPtrT<KV_t>& v, const size_t layer_idx,
+                        const AttentionActivationsPtrs& activations,
+                        MatPtrT<float>& att_out,
+                        const uint32_t* HWY_RESTRICT out_offsets,
+                        ThreadingContext& ctx, const size_t worker) {
   GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention);
   constexpr int kHTileSize = kNFx8HTileSize;
@@ -419,13 +420,15 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
 // Sweeps a tile of 4 Q rows by NF K timesteps accumulators from start_pos to
 // min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
 // max_last_pos].
-void TileFlashAttention4(
-    const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
-    const MatPtrT<KV_t>& k, const size_t start_pos,
-    const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos,
-    const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
-    const LayerWeightsPtrs& layer, const AttentionActivations& activations,
-    MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets,
-    ThreadingContext& ctx, const size_t worker) {
+void TileFlashAttention4(const MatPtrT<float>& q,
+                         const uint32_t* HWY_RESTRICT q_offsets,
+                         const MatPtrT<KV_t>& k, const size_t start_pos,
+                         const uint32_t* HWY_RESTRICT last_pos,
+                         const size_t min_last_pos, const size_t max_last_pos,
+                         const MatPtrT<KV_t>& v, const size_t layer_idx,
+                         const AttentionActivationsPtrs& activations,
+                         MatPtrT<float>& att_out,
+                         const uint32_t* HWY_RESTRICT out_offsets,
+                         ThreadingContext& ctx, const size_t worker) {
   GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention4);
   using DF = hn::ScalableTag<float>;
@@ -589,14 +592,15 @@ size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens,
 // grouped together so that mode 1 or 2 can be used, and choosing which of the
 // 3 modes to use for best efficiency.
 void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
-                    const size_t layer_idx, const LayerWeightsPtrs& layer,
-                    AttentionActivations& activations, QBatch& qbatch,
+                    const size_t layer_idx,
+                    const MatPtrT<float>& query_norm_scale,
+                    AttentionActivationsPtrs& activations, QBatch& qbatch,
                     ThreadingContext& ctx) {
   GCPP_ZONE(ctx, 0, Zones::kFlashAttentionInclusive);
-  RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q, layer_idx,
-                               layer, activations, ctx);
+  RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q,
+                               query_norm_scale, layer_idx, activations, ctx);
   const hwy::Divisor div_qbatch(qbatch.Size());
-  const LayerConfig& layer_config = layer.layer_config;
+  const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
   const size_t qkv_dim = layer_config.qkv_dim;
 
   // A "head group" in the context of GQA refers to a collection of query
@@ -732,12 +736,12 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
         // is used above to catch all cases where qT will be used.
         TileFlashAttention(activations.q, q_offsets, qT, k,
                            start_positions[offset], last_pos, min_last_pos,
-                           max_last_pos, v, layer_idx, layer, activations,
+                           max_last_pos, v, layer_idx, activations,
                            activations.att_out, out_offsets, ctx, worker);
       } else if (kVTileSize == 4) {
         TileFlashAttention4(activations.q, q_offsets, k,
                             start_positions[offset], last_pos, min_last_pos,
-                            max_last_pos, v, layer_idx, layer, activations,
+                            max_last_pos, v, layer_idx, activations,
                             activations.att_out, out_offsets, ctx, worker);
       } else {
         HWY_UNREACHABLE;
@@ -746,7 +750,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
       } else {
         SingleFlashAttention(start_positions[offset], last_pos[offset],
                              activations.q.Row(0) + q_offsets[offset], k, v,
-                             layer_idx, layer, activations,
+                             layer_idx, activations,
                              activations.att_out.Row(0) + out_offsets[offset],
                              ctx, worker);
       }

@@ -28,17 +28,16 @@ namespace gcpp {
 // Passed to HWY_VISIT_TARGETS; declares for one target.
 #define GEMMA_DECL_FLASH_ATTENTION(TARGET, NAMESPACE) \
   namespace NAMESPACE { \
-  void RMSNormAndPositionalEncoding(size_t num_tokens, const QBatch& qbatch, \
-                                    MatPtrT<KV_t>& q, size_t layer_idx, \
-                                    const LayerWeightsPtrs& layer, \
-                                    const AttentionActivations& activations, \
-                                    ThreadingContext& ctx); \
+  void RMSNormAndPositionalEncoding( \
+      size_t num_tokens, const QBatch& qbatch, MatPtrT<float>& q, \
+      const MatPtrT<float>& query_norm_scale, size_t layer_idx, \
+      const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \
 \
   void SingleFlashAttention(size_t start_pos, size_t last_pos, \
                             const float* HWY_RESTRICT q, \
                             const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
-                            size_t layer_idx, const LayerWeightsPtrs& layer, \
-                            const AttentionActivations& activations, \
+                            size_t layer_idx, \
+                            const AttentionActivationsPtrs& activations, \
                             float* HWY_RESTRICT att_out, \
                             ThreadingContext& ctx, size_t worker); \
 \
@@ -46,8 +45,9 @@ namespace gcpp {
                        size_t total_tasks, size_t target_parallelism); \
 \
   void FlashAttention(size_t num_tokens, size_t target_parallelism, \
-                      size_t layer_idx, const LayerWeightsPtrs& layer, \
-                      AttentionActivations& activations, QBatch& qbatch, \
+                      size_t layer_idx, \
+                      const MatPtrT<float>& query_norm_scale, \
+                      AttentionActivationsPtrs& activations, QBatch& qbatch, \
                       ThreadingContext& ctx); \
   /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
   }  // namespace NAMESPACE

@@ -122,8 +122,9 @@ void TestFlashAttention(size_t target_parallelism) {
   QBatch qbatch(/*start=*/0, /*max_size=*/kOuter, all_queries);
   const size_t batch_size = kOuter;
   std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
-  AttentionActivations attention(config, layer_config, batch_size, kOuter,
-                                 ctx.allocator, row_ptrs);
+  AttentionActivations attention_storage(config, layer_config, batch_size,
+                                         kOuter, ctx.allocator, row_ptrs);
+  AttentionActivationsPtrs attention(config, kOuter, attention_storage);
   const size_t qkv_dim = layer_config.qkv_dim;
   ASSERT_EQ(qkv_dim, kInner);
   const hwy::Divisor div_qbatch(qbatch.Size());
@@ -145,7 +146,8 @@ void TestFlashAttention(size_t target_parallelism) {
     SetMat(h + layer_config.heads * 2, v);
   }
   SetMat(1, attention.q);
-  DotSoftmaxWeightedSum(tokens.size(), 0, layers, attention, qbatch, ctx);
+  DotSoftmaxWeightedSum(tokens.size(), 0, layers.query_norm_scale, attention,
+                        qbatch, ctx);
   // Copy the output to saved_att to allow for comparison.
   auto saved_att = MakeCopyOfMat(attention.att_out, ctx.allocator);
   SetMat(1, attention.q);
@@ -158,8 +160,8 @@ void TestFlashAttention(size_t target_parallelism) {
                        total_tasks, target_parallelism);
   printf("FlashAttention: target_parallelism=%zu, kNF=%zu, kVTileSize=%zu\n",
          target_parallelism, kNF, kVTileSize);
-  FlashAttention(tokens.size(), target_parallelism, 0, layers, attention,
-                 qbatch, ctx);
+  FlashAttention(tokens.size(), target_parallelism, 0, layers.query_norm_scale,
+                 attention, qbatch, ctx);
   AssertClose(attention.att_out, *saved_att);
   ctx.profiler.PrintResults();
 }

@@ -51,7 +51,7 @@ struct PerQuery {
   // attention in Paligemma.
   size_t prefix_end;
 
-  KVCache& kv_cache;
+  KVCachePtr kv_cache;
 
   // Previous token generated for this query, or the last prompt token. Will be
   // fed into the next Transformer() call.
@@ -64,7 +64,7 @@ struct AllQueries {
   // For `GenerateSingleT`: same prompt/pos, replicated for each KV cache.
   AllQueries(const PromptTokens& prompt, size_t pos, size_t prefix_end,
-             const hwy::Span<KVCache>& kv_caches) {
+             const hwy::Span<KVCachePtr>& kv_caches) {
     per_query_.reserve(kv_caches.size());
     for (size_t i = 0; i < kv_caches.size(); ++i) {
      HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen());
@@ -78,11 +78,16 @@ struct AllQueries {
    }
  }
 
+  AllQueries(const PromptTokens& prompt, size_t pos, size_t prefix_end,
+             const hwy::Span<KVCache>& kv_caches)
+      : AllQueries(prompt, pos, prefix_end,
+                   hwy::Span<KVCachePtr>(ToKVCachePtrs(kv_caches))) {}
+
   // Batch of queries with initial position set to zero. Causal attention
   // is requested via empty or all-zero `prefix_end`.
   AllQueries(
       const hwy::Span<const PromptTokens>& prompts,
-      const hwy::Span<KVCache>& kv_caches,
+      const hwy::Span<KVCachePtr>& kv_caches,
       const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>()) {
     HWY_ASSERT(prompts.size() == kv_caches.size());
     HWY_ASSERT(prompts.size() == prefix_end.size() || prefix_end.size() == 0);
@@ -99,6 +104,13 @@ struct AllQueries {
    }
  }
 
+  AllQueries(
+      const hwy::Span<const PromptTokens>& prompts,
+      const hwy::Span<KVCache>& kv_caches,
+      const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>())
+      : AllQueries(prompts, hwy::Span<KVCachePtr>(ToKVCachePtrs(kv_caches)),
+                   prefix_end) {}
+
   void Reserve(size_t size) { per_query_.reserve(size); }
 
   void Append(const PerQuery& query) { per_query_.push_back(query); }
@@ -156,7 +168,7 @@ class QBatch {
   size_t PrefixEnd(size_t qi) const {
     return queries_[QueryIdx(qi)].prefix_end;
   }
-  KVCache& KV(size_t qi) const { return queries_[QueryIdx(qi)].kv_cache; }
+  KVCachePtr& KV(size_t qi) const { return queries_[QueryIdx(qi)].kv_cache; }
   int& PrevToken(size_t qi) { return queries_[QueryIdx(qi)].prev_token; }
 
  private:

@@ -16,6 +16,7 @@
 #include "gemma/kv_cache.h"
 
 #include <stddef.h>
+#include <vector>
 
 #include "gemma/configs.h"
 #include "gemma/gemma_args.h"
@@ -54,4 +55,13 @@ KVCache KVCache::Copy() {
   return copy;
 }
 
+std::vector<KVCachePtr> ToKVCachePtrs(const hwy::Span<KVCache>& kv_caches) {
+  std::vector<KVCachePtr> ptrs;
+  ptrs.reserve(kv_caches.size());
+  for (size_t i = 0; i < kv_caches.size(); ++i) {
+    ptrs.push_back(KVCachePtr{.kv_cache = kv_caches[i].kv_cache});
+  }
+  return ptrs;
+}
+
 }  // namespace gcpp

@@ -17,6 +17,7 @@
 #define THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_
 
 #include <stddef.h>
+#include <vector>
 
 #include "gemma/configs.h"     // ModelConfig
 #include "gemma/gemma_args.h"  // InferenceArgs
@@ -46,6 +47,15 @@ struct KVCache {
   KVCache(const Extents2D& kv_extents, const Allocator& allocator);
 };
 
+// A non-owning view of a KVCache.
+struct KVCachePtr {
+  size_t SeqLen() const { return kv_cache.Rows(); }
+
+  MatPtrT<KV_t> kv_cache;
+};
+
+// Convenience function to create views into KVCaches.
+std::vector<KVCachePtr> ToKVCachePtrs(const hwy::Span<KVCache>& kv_caches);
+
 }  // namespace gcpp
 
 #endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_

@@ -454,7 +454,7 @@ void TestRopeAndMulBy() {
     x.Row(0)[i] = random_float();
   }
 
-  const float qmul = AttentionActivations::ChooseQueryScale(config);
+  const float qmul = ChooseQueryScale(config);
   constexpr float kmul = 1.0f;
 
   MatStorageT<float> qexpected("qexpected", dim_qkv, ctx.allocator);

@@ -284,6 +284,9 @@ class MatPtrT : public MatPtr {
  public:
   using T = MatT;
 
+  // Default constructor for use with uninitialized views.
+  MatPtrT() = default;
+
   // Called by `MatStorageT`.
   MatPtrT(const char* name, Extents2D extents)
       : MatPtr(name, TypeEnum<MatT>(), extents) {}