[Gemma.cpp] Allows non-owning arguments for attention methods.

* Adds and uses a new `AttentionActivationsPtrs` struct that holds non-owning `MatPtr`s and acts as a view into `AttentionActivations` (see the first sketch after this list).
* Updates `QBatch` to hold non-owning `MatPtr`s to the KV caches, via the new `KVCachePtr` view (see the second sketch below).
* Enables the `MatPtrT` default constructor so that views can be default-constructed.
* Passes `LayerWeightsPtrs::query_norm_scale` directly instead of the full `LayerWeightsPtrs`. Although `LayerWeightsPtrs` already held non-owning `MatPtr`s, this avoids having to find and construct several empty weight tensors just to build the single `query_norm_scale` tensor.
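
To make the ownership split concrete, here is a minimal sketch of the pattern. The types below (`MatStorage`, `MatPtr`, `Activations`, `ActivationsPtrs`, `Attention`) are simplified stand-ins for illustration, not the actual gemma.cpp classes: an owning storage struct is built once, a cheap non-owning view is constructed over it, and an attention-style entry point receives only the view plus the single `query_norm_scale` tensor it needs.

```cpp
#include <cstddef>
#include <vector>

struct MatStorage {    // owns its buffer (cf. MatStorageT)
  explicit MatStorage(std::size_t n) : data(n) {}
  std::vector<float> data;
};

struct MatPtr {        // non-owning view (cf. MatPtrT)
  MatPtr() = default;  // default-constructible, like the new MatPtrT()
  MatPtr(MatStorage& s) : ptr(s.data.data()), size(s.data.size()) {}
  bool HasPtr() const { return ptr != nullptr; }
  float* ptr = nullptr;
  std::size_t size = 0;
};

struct Activations {   // owning storage, cf. AttentionActivations
  Activations(std::size_t batch, std::size_t dim)
      : q(batch * dim), att_out(batch * dim) {}
  MatStorage q, att_out;
};

struct ActivationsPtrs {  // non-owning view, cf. AttentionActivationsPtrs
  ActivationsPtrs() = default;
  explicit ActivationsPtrs(Activations& a) : q(a.q), att_out(a.att_out) {}
  MatPtr q, att_out;
};

// Attention-style function: takes the view plus only the single weight tensor
// it needs, instead of a full LayerWeightsPtrs-like bundle.
void Attention(const MatPtr& query_norm_scale, ActivationsPtrs& act) {
  if (query_norm_scale.HasPtr()) {
    // would apply RMSNorm to act.q here
  }
  if (act.q.HasPtr()) act.att_out.ptr[0] = act.q.ptr[0];  // placeholder work
}

int main() {
  Activations storage(/*batch=*/4, /*dim=*/256);  // owns the buffers
  ActivationsPtrs view(storage);                  // cheap, non-owning
  MatPtr query_norm_scale;                        // empty: no query-norm weights
  Attention(query_norm_scale, view);
  return 0;
}
```

In the real code this shape shows up in `Activations`, which now holds an owning `AttentionActivations attention_storage` next to an `AttentionActivationsPtrs attention` view, and in the attention functions, which now take `const MatPtrT<float>& query_norm_scale` instead of a `const LayerWeightsPtrs&`.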
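
The KV-cache change follows the same pattern. A rough sketch of the idea behind `KVCachePtr` and `ToKVCachePtrs`, again with simplified stand-ins (`Cache`, `CachePtr`, and `ToCachePtrs` are hypothetical, not the real API):

```cpp
#include <cstddef>
#include <vector>

struct Cache {  // owning, cf. KVCache
  Cache(std::size_t rows, std::size_t cols) : rows(rows), data(rows * cols) {}
  std::size_t rows;
  std::vector<float> data;
};

struct CachePtr {  // non-owning view, cf. KVCachePtr
  std::size_t SeqLen() const { return rows; }
  float* data = nullptr;
  std::size_t rows = 0;
};

// cf. ToKVCachePtrs(): build cheap views so that per-query state (PerQuery /
// QBatch in gemma.cpp) can hold views instead of references to owning caches.
std::vector<CachePtr> ToCachePtrs(std::vector<Cache>& caches) {
  std::vector<CachePtr> ptrs;
  ptrs.reserve(caches.size());
  for (Cache& c : caches) {
    ptrs.push_back(CachePtr{c.data.data(), c.rows});
  }
  return ptrs;
}

int main() {
  std::vector<Cache> caches;
  caches.emplace_back(/*rows=*/1024, /*cols=*/512);   // owning caches stay alive
  std::vector<CachePtr> views = ToCachePtrs(caches);  // non-owning views
  return views[0].SeqLen() == 1024 ? 0 : 1;
}
```

The `AllQueries` overloads that still accept `hwy::Span<KVCache>` simply convert through `ToKVCachePtrs`, so existing callers keep working while `PerQuery`/`QBatch` hold `KVCachePtr` views.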

PiperOrigin-RevId: 824584177
Author: Biruk Mammo (2025-10-27 10:42:46 -07:00), committed by Copybara-Service
Commit: 5a05857deb (parent: 86200ce224)
12 changed files with 185 additions and 120 deletions


@@ -139,7 +139,6 @@ cc_test(
":kv_cache",
":mat",
":matmul",
":ops",
":threading_context",
":weights",
"@googletest//:gtest_main", # buildcleaner: keep


@@ -31,26 +31,24 @@
namespace gcpp {
// Returns the scale value to use for the query in the attention computation.
// Also called by ops_test.
static inline float ChooseQueryScale(const ModelConfig& config) {
const LayerConfig& layer_config = config.layer_configs[0];
if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads)
return 1.0f /
sqrtf(static_cast<float>(config.model_dim / layer_config.heads));
// QueryScaleType::SqrtKeySize
return 1.0f / sqrtf(static_cast<float>(layer_config.qkv_dim));
}
struct AttentionActivations {
AttentionActivations(
const ModelConfig& config, const LayerConfig& layer_config,
size_t batch_size, size_t seq_len, const Allocator& allocator,
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
: config(config),
// `vocab_size == 0` means it is for Vit part, VitAttention is still MHA
// and does not use an external KV cache.
: // `vocab_size == 0` means it is for Vit part, VitAttention is still
// MHA and does not use an external KV cache.
q(MatFactory("q", batch_size,
config.vocab_size == 0
? layer_config.heads * 3 * layer_config.qkv_dim
@@ -76,11 +74,7 @@ struct AttentionActivations {
layer_config.post_qk == PostQKType::HalfRope)),
inv_timescale_global(CreateInvTimescale(
allocator, layer_config.qkv_dim,
layer_config.post_qk == PostQKType::HalfRope, 1000000.0)),
div_seq_len(static_cast<uint32_t>(seq_len)),
div_heads(static_cast<uint32_t>(layer_config.heads)),
query_scale(ChooseQueryScale(config)) {
layer_config.post_qk == PostQKType::HalfRope, 1000000.0)) {
// Batch size can be 0 in experimental code so do not assert.
if (batch_size == 0) {
static std::atomic_flag warned = ATOMIC_FLAG_INIT;
@@ -108,9 +102,7 @@ struct AttentionActivations {
att_sums.OverrideRows(batch_size);
}
const ModelConfig& config;
MatStorageT<float> q; // query
MatStorageT<float> q_T; // Transposed to maximize attention speed.
MatStorageT<float> pre_att_rms_out;
@@ -122,9 +114,39 @@ struct AttentionActivations {
// Rope
MatStorageT<float> inv_timescale;
MatStorageT<float> inv_timescale_global;
};
// A non-owning view of AttentionActivations.
struct AttentionActivationsPtrs {
AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len)
: config(config),
div_seq_len(static_cast<uint32_t>(seq_len)),
div_heads(static_cast<uint32_t>(config.layer_configs[0].heads)),
query_scale(ChooseQueryScale(config)) {}
AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len,
const AttentionActivations& activations)
: AttentionActivationsPtrs(config, seq_len) {
q = activations.q;
q_T = activations.q_T;
pre_att_rms_out = activations.pre_att_rms_out;
att = activations.att;
att_out = activations.att_out;
att_sums = activations.att_sums;
inv_timescale = activations.inv_timescale;
inv_timescale_global = activations.inv_timescale_global;
}
const ModelConfig& config;
MatPtrT<float> q;
MatPtrT<float> q_T;
MatPtrT<float> pre_att_rms_out;
MatPtrT<float> att;
MatPtrT<float> att_out;
MatPtrT<BF16> att_sums;
MatPtrT<float> inv_timescale;
MatPtrT<float> inv_timescale_global;
hwy::Divisor div_seq_len;
// Unfortunately, some models have had non-power-of-two heads.
hwy::Divisor div_heads;
float query_scale;
};
@@ -150,8 +172,9 @@ struct Activations {
ffw_out(
MatFactory("ffw_out", batch_size, config.model_dim, ctx.allocator)),
attention(config, layer_config, batch_size, seq_len, ctx.allocator,
row_ptrs) {
attention_storage(config, layer_config, batch_size, seq_len,
ctx.allocator, row_ptrs),
attention(config, seq_len, attention_storage) {
HWY_ASSERT(batch_size != 0);
// For MatMul outputs, precompute their row pointers.
@@ -179,12 +202,12 @@ struct Activations {
C2.OverrideRows(batch_size);
ffw_out.OverrideRows(batch_size);
attention.SetBatchSize(batch_size);
attention_storage.SetBatchSize(batch_size);
}
const LayerConfig& layer_config;
MatStorageT<float> x; // input
MatStorageT<BF16> x_bf; // output of final RMSNorm, input to EmbeddingMatmul
MatStorageT<float> logits; // TODO: BF16 after Softmax supports that.
MatStorageT<uint32_t> sampled; // batch_size x 3 (padded)
@@ -195,7 +218,8 @@ struct Activations {
MatStorageT<BF16> C2;
MatStorageT<float> ffw_out;
AttentionActivations attention;
AttentionActivations attention_storage;
AttentionActivationsPtrs attention;
};
} // namespace gcpp


@@ -73,12 +73,12 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
}
void PositionalEncodingQK(float* qk, const size_t layer_idx,
const LayerWeightsPtrs& layer,
const AttentionActivations& activations,
const AttentionActivationsPtrs& activations,
ThreadingContext& ctx, const size_t worker,
const size_t pos, const float mul) {
const size_t qkv_dim = layer.layer_config.qkv_dim;
const PostQKType& post_qk = layer.layer_config.post_qk;
const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
const size_t qkv_dim = layer_config.qkv_dim;
const PostQKType& post_qk = layer_config.post_qk;
// qk is either q or k, so qkv_dim is the length we operate on.
const float* inv_timescale = activations.inv_timescale.PackedScale1();
const bool is_global_layer = activations.config.IsGlobalLayer(layer_idx);
@@ -130,23 +130,23 @@ static HWY_INLINE void WeightedSumV(
void SingleDotSoftmaxWeightedSum(
const size_t pos, const size_t start_pos, const size_t last_pos,
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
const size_t layer_idx, const LayerWeightsPtrs& layer,
const AttentionActivations& activations, float* HWY_RESTRICT att,
const MatPtrT<float>& query_norm_scale, const size_t layer_idx,
const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att,
float* HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) {
const float att_cap = activations.config.att_cap;
const float query_scale = activations.query_scale;
const size_t seq_len =
static_cast<size_t>(activations.div_seq_len.GetDivisor());
const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
// Apply rope and scaling to Q.
if (layer.query_norm_scale.HasPtr()) {
CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
if (query_norm_scale.HasPtr()) {
CallUpcasted(&query_norm_scale, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q,
layer.layer_config.qkv_dim, ctx, worker);
layer_config.qkv_dim, ctx, worker);
});
}
PositionalEncodingQK(q, layer_idx, layer, activations, ctx, worker, pos,
PositionalEncodingQK(q, layer_idx, activations, ctx, worker, pos,
query_scale);
QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, ctx, worker);
@@ -169,13 +169,13 @@ size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx) {
}
void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
const LayerWeightsPtrs& layer,
AttentionActivations& activations, QBatch& qbatch,
ThreadingContext& ctx) {
const MatPtrT<float>& query_norm_scale,
AttentionActivationsPtrs& activations,
QBatch& qbatch, ThreadingContext& ctx) {
GCPP_ZONE(ctx, 0, Zones::kGenAttentionDotSoftmaxWeightedSumInclusive);
const hwy::Divisor div_qbatch(qbatch.Size());
const LayerConfig& layer_config = layer.layer_config;
const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
const size_t qkv_dim = layer_config.qkv_dim;
// A "head group" in the context of GQA refers to a collection of query
@@ -223,8 +223,9 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
MatPtrT<KV_t> v("v_view", Extents2D(seq_len, qkv_dim));
v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());
SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v, layer_idx,
layer, activations, att, att_out, ctx, worker);
SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v,
query_norm_scale, layer_idx, activations, att,
att_out, ctx, worker);
};
{
@@ -245,7 +246,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
// Fills activations.q and writes to KV cache.
static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
const LayerWeightsPtrs& layer,
AttentionActivations& activations,
AttentionActivationsPtrs& activations,
const QBatch& qbatch, const int flags,
MatMulEnv& env) {
GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(),
@@ -312,8 +313,8 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
});
}
PositionalEncodingQK(kv_f32, layer_idx, layer, activations, env.ctx,
worker, pos, /*mul=*/1.0f);
PositionalEncodingQK(kv_f32, layer_idx, activations, env.ctx, worker,
pos, /*mul=*/1.0f);
CompressPerThread tls;
Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
});
@@ -322,7 +323,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
// Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and
// head_dim (`qkv_dim`) into output (`layer_out`).
static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
AttentionActivations& activations,
AttentionActivationsPtrs& activations,
MatMulEnv& env) {
GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttentionSumHeads);
const LayerConfig& layer_config = layer.layer_config;
@@ -340,7 +341,7 @@ static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
void GemmaAttention(size_t num_tokens, const size_t layer_idx,
const LayerWeightsPtrs& layer,
AttentionActivations& activations, QBatch& qbatch,
AttentionActivationsPtrs& activations, QBatch& qbatch,
MatMulEnv& env, int flags) {
GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttention);
@@ -352,13 +353,14 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,
ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env);
if (flags & kAttentionUseOld) {
DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch,
env.ctx);
DotSoftmaxWeightedSum(num_tokens, layer_idx, layer.query_norm_scale,
activations, qbatch, env.ctx);
} else {
// * 2 does not help on Turin.
FlashAttention(num_tokens,
/*target_parallelism=*/env.ctx.pools.MaxWorkers() * 1,
layer_idx, layer, activations, qbatch, env.ctx);
layer_idx, layer.query_norm_scale, activations, qbatch,
env.ctx);
}
SumHeads(layer, activations, env);
}


@@ -29,8 +29,7 @@ namespace gcpp {
#define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \
namespace NAMESPACE { \
void PositionalEncodingQK(float* qk, size_t layer_idx, \
const LayerWeightsPtrs& layer, \
const AttentionActivations& activations, \
const AttentionActivationsPtrs& activations, \
ThreadingContext& ctx, size_t worker, size_t pos, \
float mul); \
\
@@ -39,18 +38,18 @@ namespace gcpp {
void SingleDotSoftmaxWeightedSum( \
const size_t pos, const size_t start_pos, const size_t last_pos, \
float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
size_t layer_idx, const LayerWeightsPtrs& layer, \
const AttentionActivations& activations, float* HWY_RESTRICT att, \
const MatPtrT<float>& query_norm_scale, size_t layer_idx, \
const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att, \
float* HWY_RESTRICT att_out, ThreadingContext& ctx, size_t worker); \
\
void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \
const LayerWeightsPtrs& layer, \
AttentionActivations& activations, \
const MatPtrT<float>& query_norm_scale, \
AttentionActivationsPtrs& activations, \
QBatch& qbatch, ThreadingContext& ctx); \
\
void GemmaAttention(size_t num_tokens, const size_t layer_idx, \
const LayerWeightsPtrs& layer, \
AttentionActivations& activations, QBatch& qbatch, \
AttentionActivationsPtrs& activations, QBatch& qbatch, \
MatMulEnv& env, int flags); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE


@@ -30,7 +30,6 @@
#include "gemma/activations.h"
#include "gemma/configs.h" // kMaxQKVDim
#include "gemma/gemma.h"
#include "gemma/weights.h"
#include "util/threading.h"
#include "hwy/profiler.h"
@@ -91,32 +90,33 @@ static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
// Updates q in place for RMSNorm and positional encoding.
void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
MatPtrT<KV_t>& q, const size_t layer_idx,
const LayerWeightsPtrs& layer,
const AttentionActivations& activations,
MatPtrT<float>& q,
const MatPtrT<float>& query_norm_scale,
const size_t layer_idx,
const AttentionActivationsPtrs& activations,
ThreadingContext& ctx) {
const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
const float query_scale = activations.query_scale;
const hwy::Divisor div_qbatch(qbatch.Size());
const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionRmsNormAndPositionalEncoding);
size_t qi = div_qbatch.Remainder(task);
size_t batch_idx = div_qbatch.Divide(task);
for (size_t h = 0; h < layer.layer_config.heads; ++h) {
for (size_t h = 0; h < layer_config.heads; ++h) {
const size_t tq_idx = qbatch.Size() * batch_idx + qi;
// Find the token position in the query and calculate
// the range of cache positions to attend to.
const size_t pos = qbatch.Pos(qi) + batch_idx;
float* HWY_RESTRICT q_row =
q.Row(tq_idx) + h * layer.layer_config.qkv_dim;
float* HWY_RESTRICT q_row = q.Row(tq_idx) + h * layer_config.qkv_dim;
// Apply rope and scaling to Q.
if (layer.query_norm_scale.HasPtr()) {
CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
if (query_norm_scale.HasPtr()) {
CallUpcasted(&query_norm_scale, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q_row,
layer.layer_config.qkv_dim, ctx, worker);
layer_config.qkv_dim, ctx, worker);
});
}
PositionalEncodingQK(q_row, layer_idx, layer, activations, ctx, worker,
pos, query_scale);
PositionalEncodingQK(q_row, layer_idx, activations, ctx, worker, pos,
query_scale);
}
};
{
@@ -154,8 +154,7 @@ void HWY_INLINE SingleFlashAttentionStep(float x, float cap, float& old_max,
void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
const float* HWY_RESTRICT q, const MatPtrT<KV_t>& k,
const MatPtrT<KV_t>& v, const size_t layer_idx,
const LayerWeightsPtrs& layer,
const AttentionActivations& activations,
const AttentionActivationsPtrs& activations,
float* HWY_RESTRICT att_out, ThreadingContext& ctx,
const size_t worker) {
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionSingleFlashAttention);
@@ -265,15 +264,17 @@ VF HWY_INLINE ElementwiseSumOf8(DF df, const VF& x0, const VF& x1, const VF& x2,
// Sweeps a tile of NF Q rows by 8 K timesteps accumulators from start_pos to
// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
// max_last_pos].
void TileFlashAttention(
const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
const StridedView<float>& qT, const MatPtrT<KV_t>& k,
const size_t start_pos, const uint32_t* HWY_RESTRICT last_pos,
const size_t min_last_pos, const size_t max_last_pos,
const MatPtrT<KV_t>& v, const size_t layer_idx,
const LayerWeightsPtrs& layer, const AttentionActivations& activations,
MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets,
ThreadingContext& ctx, const size_t worker) {
void TileFlashAttention(const MatPtrT<float>& q,
const uint32_t* HWY_RESTRICT q_offsets,
const StridedView<float>& qT, const MatPtrT<KV_t>& k,
const size_t start_pos,
const uint32_t* HWY_RESTRICT last_pos,
const size_t min_last_pos, const size_t max_last_pos,
const MatPtrT<KV_t>& v, const size_t layer_idx,
const AttentionActivationsPtrs& activations,
MatPtrT<float>& att_out,
const uint32_t* HWY_RESTRICT out_offsets,
ThreadingContext& ctx, const size_t worker) {
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention);
constexpr int kHTileSize = kNFx8HTileSize;
using DF = hn::ScalableTag<float>;
@@ -419,14 +420,16 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
// Sweeps a tile of 4 Q rows by NF K timesteps accumulators from start_pos to
// min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
// max_last_pos].
void TileFlashAttention4(
const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
const MatPtrT<KV_t>& k, const size_t start_pos,
const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos,
const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
const LayerWeightsPtrs& layer, const AttentionActivations& activations,
MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets,
ThreadingContext& ctx, const size_t worker) {
void TileFlashAttention4(const MatPtrT<float>& q,
const uint32_t* HWY_RESTRICT q_offsets,
const MatPtrT<KV_t>& k, const size_t start_pos,
const uint32_t* HWY_RESTRICT last_pos,
const size_t min_last_pos, const size_t max_last_pos,
const MatPtrT<KV_t>& v, const size_t layer_idx,
const AttentionActivationsPtrs& activations,
MatPtrT<float>& att_out,
const uint32_t* HWY_RESTRICT out_offsets,
ThreadingContext& ctx, const size_t worker) {
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention4);
using DF = hn::ScalableTag<float>;
const DF df;
@@ -589,14 +592,15 @@ size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens,
// grouped together so that mode 1 or 2 can be used, and choosing which of the
// 3 modes to use for best efficiency.
void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
const size_t layer_idx, const LayerWeightsPtrs& layer,
AttentionActivations& activations, QBatch& qbatch,
const size_t layer_idx,
const MatPtrT<float>& query_norm_scale,
AttentionActivationsPtrs& activations, QBatch& qbatch,
ThreadingContext& ctx) {
GCPP_ZONE(ctx, 0, Zones::kFlashAttentionInclusive);
RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q, layer_idx,
layer, activations, ctx);
RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q,
query_norm_scale, layer_idx, activations, ctx);
const hwy::Divisor div_qbatch(qbatch.Size());
const LayerConfig& layer_config = layer.layer_config;
const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
const size_t qkv_dim = layer_config.qkv_dim;
// A "head group" in the context of GQA refers to a collection of query
@@ -732,12 +736,12 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
// is used above to catch all cases where qT will be used.
TileFlashAttention(activations.q, q_offsets, qT, k,
start_positions[offset], last_pos, min_last_pos,
max_last_pos, v, layer_idx, layer, activations,
max_last_pos, v, layer_idx, activations,
activations.att_out, out_offsets, ctx, worker);
} else if (kVTileSize == 4) {
TileFlashAttention4(activations.q, q_offsets, k,
start_positions[offset], last_pos, min_last_pos,
max_last_pos, v, layer_idx, layer, activations,
max_last_pos, v, layer_idx, activations,
activations.att_out, out_offsets, ctx, worker);
} else {
HWY_UNREACHABLE;
@@ -746,7 +750,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
} else {
SingleFlashAttention(start_positions[offset], last_pos[offset],
activations.q.Row(0) + q_offsets[offset], k, v,
layer_idx, layer, activations,
layer_idx, activations,
activations.att_out.Row(0) + out_offsets[offset],
ctx, worker);
}


@@ -28,17 +28,16 @@ namespace gcpp {
// Passed to HWY_VISIT_TARGETS; declares for one target.
#define GEMMA_DECL_FLASH_ATTENTION(TARGET, NAMESPACE) \
namespace NAMESPACE { \
void RMSNormAndPositionalEncoding(size_t num_tokens, const QBatch& qbatch, \
MatPtrT<KV_t>& q, size_t layer_idx, \
const LayerWeightsPtrs& layer, \
const AttentionActivations& activations, \
ThreadingContext& ctx); \
void RMSNormAndPositionalEncoding( \
size_t num_tokens, const QBatch& qbatch, MatPtrT<float>& q, \
const MatPtrT<float>& query_norm_scale, size_t layer_idx, \
const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \
\
void SingleFlashAttention(size_t start_pos, size_t last_pos, \
const float* HWY_RESTRICT q, \
const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
size_t layer_idx, const LayerWeightsPtrs& layer, \
const AttentionActivations& activations, \
size_t layer_idx, \
const AttentionActivationsPtrs& activations, \
float* HWY_RESTRICT att_out, \
ThreadingContext& ctx, size_t worker); \
\
@@ -46,8 +45,9 @@ namespace gcpp {
size_t total_tasks, size_t target_parallelism); \
\
void FlashAttention(size_t num_tokens, size_t target_parallelism, \
size_t layer_idx, const LayerWeightsPtrs& layer, \
AttentionActivations& activations, QBatch& qbatch, \
size_t layer_idx, \
const MatPtrT<float>& query_norm_scale, \
AttentionActivationsPtrs& activations, QBatch& qbatch, \
ThreadingContext& ctx); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE


@@ -122,8 +122,9 @@ void TestFlashAttention(size_t target_parallelism) {
QBatch qbatch(/*start=*/0, /*max_size=*/kOuter, all_queries);
const size_t batch_size = kOuter;
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
AttentionActivations attention(config, layer_config, batch_size, kOuter,
ctx.allocator, row_ptrs);
AttentionActivations attention_storage(config, layer_config, batch_size,
kOuter, ctx.allocator, row_ptrs);
AttentionActivationsPtrs attention(config, kOuter, attention_storage);
const size_t qkv_dim = layer_config.qkv_dim;
ASSERT_EQ(qkv_dim, kInner);
const hwy::Divisor div_qbatch(qbatch.Size());
@@ -145,7 +146,8 @@ void TestFlashAttention(size_t target_parallelism) {
SetMat(h + layer_config.heads * 2, v);
}
SetMat(1, attention.q);
DotSoftmaxWeightedSum(tokens.size(), 0, layers, attention, qbatch, ctx);
DotSoftmaxWeightedSum(tokens.size(), 0, layers.query_norm_scale, attention,
qbatch, ctx);
// Copy the output to saved_att to allow for comparison.
auto saved_att = MakeCopyOfMat(attention.att_out, ctx.allocator);
SetMat(1, attention.q);
@@ -158,8 +160,8 @@ void TestFlashAttention(size_t target_parallelism) {
total_tasks, target_parallelism);
printf("FlashAttention: target_parallelism=%zu, kNF=%zu, kVTileSize=%zu\n",
target_parallelism, kNF, kVTileSize);
FlashAttention(tokens.size(), target_parallelism, 0, layers, attention,
qbatch, ctx);
FlashAttention(tokens.size(), target_parallelism, 0, layers.query_norm_scale,
attention, qbatch, ctx);
AssertClose(attention.att_out, *saved_att);
ctx.profiler.PrintResults();
}


@@ -51,7 +51,7 @@ struct PerQuery {
// attention in Paligemma.
size_t prefix_end;
KVCache& kv_cache;
KVCachePtr kv_cache;
// Previous token generated for this query, or the last prompt token. Will be
// fed into the next Transformer() call.
@@ -64,7 +64,7 @@ struct AllQueries {
// For `GenerateSingleT`: same prompt/pos, replicated for each KV cache.
AllQueries(const PromptTokens& prompt, size_t pos, size_t prefix_end,
const hwy::Span<KVCache>& kv_caches) {
const hwy::Span<KVCachePtr>& kv_caches) {
per_query_.reserve(kv_caches.size());
for (size_t i = 0; i < kv_caches.size(); ++i) {
HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen());
@@ -78,11 +78,16 @@ struct AllQueries {
}
}
AllQueries(const PromptTokens& prompt, size_t pos, size_t prefix_end,
const hwy::Span<KVCache>& kv_caches)
: AllQueries(prompt, pos, prefix_end,
hwy::Span<KVCachePtr>(ToKVCachePtrs(kv_caches))) {}
// Batch of queries with initial position set to zero. Causal attention
// is requested via empty or all-zero `prefix_end`.
AllQueries(
const hwy::Span<const PromptTokens>& prompts,
const hwy::Span<KVCache>& kv_caches,
const hwy::Span<KVCachePtr>& kv_caches,
const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>()) {
HWY_ASSERT(prompts.size() == kv_caches.size());
HWY_ASSERT(prompts.size() == prefix_end.size() || prefix_end.size() == 0);
@@ -99,6 +104,13 @@ struct AllQueries {
}
}
AllQueries(
const hwy::Span<const PromptTokens>& prompts,
const hwy::Span<KVCache>& kv_caches,
const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>())
: AllQueries(prompts, hwy::Span<KVCachePtr>(ToKVCachePtrs(kv_caches)),
prefix_end) {}
void Reserve(size_t size) { per_query_.reserve(size); }
void Append(const PerQuery& query) { per_query_.push_back(query); }
@@ -156,7 +168,7 @@ class QBatch {
size_t PrefixEnd(size_t qi) const {
return queries_[QueryIdx(qi)].prefix_end;
}
KVCache& KV(size_t qi) const { return queries_[QueryIdx(qi)].kv_cache; }
KVCachePtr& KV(size_t qi) const { return queries_[QueryIdx(qi)].kv_cache; }
int& PrevToken(size_t qi) { return queries_[QueryIdx(qi)].prev_token; }
private:


@@ -16,6 +16,7 @@
#include "gemma/kv_cache.h"
#include <stddef.h>
#include <vector>
#include "gemma/configs.h"
#include "gemma/gemma_args.h"
@@ -54,4 +55,13 @@ KVCache KVCache::Copy() {
return copy;
}
std::vector<KVCachePtr> ToKVCachePtrs(const hwy::Span<KVCache>& kv_caches) {
std::vector<KVCachePtr> ptrs;
ptrs.reserve(kv_caches.size());
for (size_t i = 0; i < kv_caches.size(); ++i) {
ptrs.push_back(KVCachePtr{.kv_cache = kv_caches[i].kv_cache});
}
return ptrs;
}
} // namespace gcpp


@@ -17,6 +17,7 @@
#define THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_
#include <stddef.h>
#include <vector>
#include "gemma/configs.h" // ModelConfig
#include "gemma/gemma_args.h" // InferenceArgs
@@ -46,6 +47,15 @@ struct KVCache {
KVCache(const Extents2D& kv_extents, const Allocator& allocator);
};
// A non-owning view of a KVCache.
struct KVCachePtr {
size_t SeqLen() const { return kv_cache.Rows(); }
MatPtrT<KV_t> kv_cache;
};
// Convenience function to create views into KVCaches.
std::vector<KVCachePtr> ToKVCachePtrs(const hwy::Span<KVCache>& kv_caches);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_


@@ -454,7 +454,7 @@ void TestRopeAndMulBy() {
x.Row(0)[i] = random_float();
}
const float qmul = AttentionActivations::ChooseQueryScale(config);
const float qmul = ChooseQueryScale(config);
constexpr float kmul = 1.0f;
MatStorageT<float> qexpected("qexpected", dim_qkv, ctx.allocator);


@@ -284,6 +284,9 @@ class MatPtrT : public MatPtr {
public:
using T = MatT;
// Default constructor for use with uninitialized views.
MatPtrT() = default;
// Called by `MatStorageT`.
MatPtrT(const char* name, Extents2D extents)
: MatPtr(name, TypeEnum<MatT>(), extents) {}