mirror of https://github.com/google/gemma.cpp.git
[Gemma.cpp] Allows non-owned arguments for attention methods.
* Adds and uses a new `AttentionActivationsPtrs` that holds non-owning `MatPtr`s and acts as a view into `AttentionActivations`.
* Updates `QBatch` to hold non-owning `MatPtr`s to the KV caches.
* Enables the `MatPtrT` default constructor for simpler initializations.
* Pulls out and passes `LayerWeightsPtrs::query_norm_scale` directly. While `LayerWeightsPtrs` already held non-owning `MatPtr`s, this change avoids having to find and construct several empty weight tensors just to build the one `query_norm_scale` tensor.

PiperOrigin-RevId: 824584177
Commit: 5a05857deb
Parent: 86200ce224
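To make the commit's central idea concrete, here is a minimal, self-contained C++ sketch of the pattern it describes: owning storage plus a default-constructible, non-owning view that is cheap to build from it. The names below (FloatView, FloatStorage, AttentionViews) are simplified stand-ins for illustration only, not the real MatPtrT / MatStorageT / AttentionActivationsPtrs classes.

    // Hypothetical, simplified stand-ins; not the actual gemma.cpp classes.
    #include <cstddef>
    #include <vector>

    // Non-owning view: just a pointer plus extents. Default-constructible so a
    // view struct can be declared first and bound to storage later.
    struct FloatView {
      FloatView() = default;
      FloatView(float* data, size_t rows, size_t cols)
          : data(data), rows(rows), cols(cols) {}
      float* data = nullptr;
      size_t rows = 0, cols = 0;
    };

    // Owning storage: allocates the buffer and hands out views.
    struct FloatStorage {
      FloatStorage(size_t rows, size_t cols)
          : rows(rows), cols(cols), buf(rows * cols) {}
      FloatView View() { return FloatView(buf.data(), rows, cols); }
      size_t rows, cols;
      std::vector<float> buf;
    };

    // Analogous in spirit to AttentionActivationsPtrs: holds only views, so it
    // can be constructed from existing storage without allocating or copying.
    struct AttentionViews {
      explicit AttentionViews(FloatStorage& storage) : q(storage.View()) {}
      FloatView q;
    };

    int main() {
      FloatStorage q_storage(/*rows=*/4, /*cols=*/256);  // owns the memory
      AttentionViews views(q_storage);                   // borrows it
      views.q.data[0] = 1.0f;                            // writes into q_storage
      return 0;
    }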
@@ -139,7 +139,6 @@ cc_test(
         ":kv_cache",
         ":mat",
         ":matmul",
-        ":ops",
         ":threading_context",
         ":weights",
         "@googletest//:gtest_main",  # buildcleaner: keep

@@ -31,26 +31,24 @@
 
 namespace gcpp {
 
-struct AttentionActivations {
-  // Returns the scale value to use for the query in the attention computation.
-  // Also called by ops_test.
-  static inline float ChooseQueryScale(const ModelConfig& config) {
-    const LayerConfig& layer_config = config.layer_configs[0];
-    if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads)
-      return 1.0f /
-             sqrtf(static_cast<float>(config.model_dim / layer_config.heads));
-    // QueryScaleType::SqrtKeySize
-    return 1.0f / sqrtf(static_cast<float>(layer_config.qkv_dim));
-  }
+// Returns the scale value to use for the query in the attention computation.
+// Also called by ops_test.
+static inline float ChooseQueryScale(const ModelConfig& config) {
+  const LayerConfig& layer_config = config.layer_configs[0];
+  if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads)
+    return 1.0f /
+           sqrtf(static_cast<float>(config.model_dim / layer_config.heads));
+  // QueryScaleType::SqrtKeySize
+  return 1.0f / sqrtf(static_cast<float>(layer_config.qkv_dim));
+}
 
+struct AttentionActivations {
   AttentionActivations(
       const ModelConfig& config, const LayerConfig& layer_config,
       size_t batch_size, size_t seq_len, const Allocator& allocator,
       std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
-      : config(config),
-        // `vocab_size == 0` means it is for Vit part, VitAttention is still MHA
-        // and does not use an external KV cache.
+      : // `vocab_size == 0` means it is for Vit part, VitAttention is still
+        // MHA and does not use an external KV cache.
         q(MatFactory("q", batch_size,
                      config.vocab_size == 0
                          ? layer_config.heads * 3 * layer_config.qkv_dim

@@ -76,11 +74,7 @@ struct AttentionActivations {
                      layer_config.post_qk == PostQKType::HalfRope)),
         inv_timescale_global(CreateInvTimescale(
             allocator, layer_config.qkv_dim,
-            layer_config.post_qk == PostQKType::HalfRope, 1000000.0)),
-        div_seq_len(static_cast<uint32_t>(seq_len)),
-        div_heads(static_cast<uint32_t>(layer_config.heads)),
-        query_scale(ChooseQueryScale(config)) {
+            layer_config.post_qk == PostQKType::HalfRope, 1000000.0)) {
     // Batch size can be 0 in experimental code so do not assert.
     if (batch_size == 0) {
       static std::atomic_flag warned = ATOMIC_FLAG_INIT;

@@ -108,9 +102,7 @@ struct AttentionActivations {
     att_sums.OverrideRows(batch_size);
   }
 
-  const ModelConfig& config;
-
   MatStorageT<float> q;    // query
   MatStorageT<float> q_T;  // Transposed to maximize attention speed.
 
   MatStorageT<float> pre_att_rms_out;

@@ -122,9 +114,39 @@ struct AttentionActivations {
   // Rope
   MatStorageT<float> inv_timescale;
   MatStorageT<float> inv_timescale_global;
+};
 
+// A non-owning view of AttentionActivations.
+struct AttentionActivationsPtrs {
+  AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len)
+      : config(config),
+        div_seq_len(static_cast<uint32_t>(seq_len)),
+        div_heads(static_cast<uint32_t>(config.layer_configs[0].heads)),
+        query_scale(ChooseQueryScale(config)) {}
+
+  AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len,
+                           const AttentionActivations& activations)
+      : AttentionActivationsPtrs(config, seq_len) {
+    q = activations.q;
+    q_T = activations.q_T;
+    pre_att_rms_out = activations.pre_att_rms_out;
+    att = activations.att;
+    att_out = activations.att_out;
+    att_sums = activations.att_sums;
+    inv_timescale = activations.inv_timescale;
+    inv_timescale_global = activations.inv_timescale_global;
+  }
+
+  const ModelConfig& config;
+  MatPtrT<float> q;
+  MatPtrT<float> q_T;
+  MatPtrT<float> pre_att_rms_out;
+  MatPtrT<float> att;
+  MatPtrT<float> att_out;
+  MatPtrT<BF16> att_sums;
+  MatPtrT<float> inv_timescale;
+  MatPtrT<float> inv_timescale_global;
   hwy::Divisor div_seq_len;
-  // Unfortunately, some models have had non-power-of-two heads.
   hwy::Divisor div_heads;
   float query_scale;
 };

@@ -150,8 +172,9 @@ struct Activations {
         ffw_out(
             MatFactory("ffw_out", batch_size, config.model_dim, ctx.allocator)),
 
-        attention(config, layer_config, batch_size, seq_len, ctx.allocator,
-                  row_ptrs) {
+        attention_storage(config, layer_config, batch_size, seq_len,
+                          ctx.allocator, row_ptrs),
+        attention(config, seq_len, attention_storage) {
     HWY_ASSERT(batch_size != 0);
 
     // For MatMul outputs, precompute their row pointers.

@@ -179,12 +202,12 @@ struct Activations {
     C2.OverrideRows(batch_size);
     ffw_out.OverrideRows(batch_size);
 
-    attention.SetBatchSize(batch_size);
+    attention_storage.SetBatchSize(batch_size);
   }
 
   const LayerConfig& layer_config;
 
   MatStorageT<float> x;    // input
   MatStorageT<BF16> x_bf;  // output of final RMSNorm, input to EmbeddingMatmul
   MatStorageT<float> logits;      // TODO: BF16 after Softmax supports that.
   MatStorageT<uint32_t> sampled;  // batch_size x 3 (padded)

@@ -195,7 +218,8 @@ struct Activations {
   MatStorageT<BF16> C2;
   MatStorageT<float> ffw_out;
 
-  AttentionActivations attention;
+  AttentionActivations attention_storage;
+  AttentionActivationsPtrs attention;
 };
 
 }  // namespace gcpp

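A side note on the Activations change above: the new view member is initialized from the storage member, which works because C++ initializes members in declaration order, so `attention_storage` must be declared before `attention`. A minimal sketch with simplified stand-in types (Storage, View), assuming nothing about the real classes beyond what the diff shows:

    #include <cstddef>
    #include <vector>

    struct Storage {
      explicit Storage(size_t n) : buf(n) {}
      std::vector<float> buf;
    };

    struct View {
      View() = default;
      explicit View(Storage& s) : data(s.buf.data()), size(s.buf.size()) {}
      float* data = nullptr;
      size_t size = 0;
    };

    struct Activations {
      explicit Activations(size_t n)
          // `storage` is declared (and therefore initialized) before `view`,
          // so the view can safely be built from it in the same init list.
          : storage(n), view(storage) {}
      Storage storage;  // owns the buffers
      View view;        // non-owning; what the kernels receive
    };

    int main() {
      Activations a(/*n=*/64);
      return a.view.size == 64 ? 0 : 1;
    }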
@@ -73,12 +73,12 @@ static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
 }
 
 void PositionalEncodingQK(float* qk, const size_t layer_idx,
-                          const LayerWeightsPtrs& layer,
-                          const AttentionActivations& activations,
+                          const AttentionActivationsPtrs& activations,
                           ThreadingContext& ctx, const size_t worker,
                           const size_t pos, const float mul) {
-  const size_t qkv_dim = layer.layer_config.qkv_dim;
-  const PostQKType& post_qk = layer.layer_config.post_qk;
+  const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
+  const size_t qkv_dim = layer_config.qkv_dim;
+  const PostQKType& post_qk = layer_config.post_qk;
   // qk is either q or k, so qkv_dim is the length we operate on.
   const float* inv_timescale = activations.inv_timescale.PackedScale1();
   const bool is_global_layer = activations.config.IsGlobalLayer(layer_idx);

@@ -130,23 +130,23 @@ static HWY_INLINE void WeightedSumV(
 void SingleDotSoftmaxWeightedSum(
     const size_t pos, const size_t start_pos, const size_t last_pos,
     float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v,
-    const size_t layer_idx, const LayerWeightsPtrs& layer,
-    const AttentionActivations& activations, float* HWY_RESTRICT att,
+    const MatPtrT<float>& query_norm_scale, const size_t layer_idx,
+    const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att,
     float* HWY_RESTRICT att_out, ThreadingContext& ctx, const size_t worker) {
   const float att_cap = activations.config.att_cap;
   const float query_scale = activations.query_scale;
   const size_t seq_len =
       static_cast<size_t>(activations.div_seq_len.GetDivisor());
-
+  const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
   // Apply rope and scaling to Q.
-  if (layer.query_norm_scale.HasPtr()) {
-    CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
+  if (query_norm_scale.HasPtr()) {
+    CallUpcasted(&query_norm_scale, [&](const auto* weights_t) {
       RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q,
-                     layer.layer_config.qkv_dim, ctx, worker);
+                     layer_config.qkv_dim, ctx, worker);
     });
   }
 
-  PositionalEncodingQK(q, layer_idx, layer, activations, ctx, worker, pos,
+  PositionalEncodingQK(q, layer_idx, activations, ctx, worker, pos,
                        query_scale);
 
   QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att, ctx, worker);

@@ -169,13 +169,13 @@ size_t StartPos(size_t pos, const ModelConfig& config, size_t layer_idx) {
 }
 
 void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
-                           const LayerWeightsPtrs& layer,
-                           AttentionActivations& activations, QBatch& qbatch,
-                           ThreadingContext& ctx) {
+                           const MatPtrT<float>& query_norm_scale,
+                           AttentionActivationsPtrs& activations,
+                           QBatch& qbatch, ThreadingContext& ctx) {
   GCPP_ZONE(ctx, 0, Zones::kGenAttentionDotSoftmaxWeightedSumInclusive);
 
   const hwy::Divisor div_qbatch(qbatch.Size());
-  const LayerConfig& layer_config = layer.layer_config;
+  const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
   const size_t qkv_dim = layer_config.qkv_dim;
 
   // A "head group" in the context of GQA refers to a collection of query

@@ -223,8 +223,9 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
     MatPtrT<KV_t> v("v_view", Extents2D(seq_len, qkv_dim));
     v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());
 
-    SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v, layer_idx,
-                                layer, activations, att, att_out, ctx, worker);
+    SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v,
+                                query_norm_scale, layer_idx, activations, att,
+                                att_out, ctx, worker);
   };
 
   {

@@ -245,7 +246,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
 // Fills activations.q and writes to KV cache.
 static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
                                   const LayerWeightsPtrs& layer,
-                                  AttentionActivations& activations,
+                                  AttentionActivationsPtrs& activations,
                                   const QBatch& qbatch, const int flags,
                                   MatMulEnv& env) {
   GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(),

@@ -312,8 +313,8 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
       });
     }
 
-    PositionalEncodingQK(kv_f32, layer_idx, layer, activations, env.ctx,
-                         worker, pos, /*mul=*/1.0f);
+    PositionalEncodingQK(kv_f32, layer_idx, activations, env.ctx, worker,
+                         pos, /*mul=*/1.0f);
     CompressPerThread tls;
     Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
   });

@@ -322,7 +323,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
 // Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and
 // head_dim (`qkv_dim`) into output (`layer_out`).
 static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
-                                AttentionActivations& activations,
+                                AttentionActivationsPtrs& activations,
                                 MatMulEnv& env) {
   GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttentionSumHeads);
   const LayerConfig& layer_config = layer.layer_config;

@@ -340,7 +341,7 @@ static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
 
 void GemmaAttention(size_t num_tokens, const size_t layer_idx,
                     const LayerWeightsPtrs& layer,
-                    AttentionActivations& activations, QBatch& qbatch,
+                    AttentionActivationsPtrs& activations, QBatch& qbatch,
                     MatMulEnv& env, int flags) {
   GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttention);
 

@@ -352,13 +353,14 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,
 
   ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env);
   if (flags & kAttentionUseOld) {
-    DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch,
-                          env.ctx);
+    DotSoftmaxWeightedSum(num_tokens, layer_idx, layer.query_norm_scale,
+                          activations, qbatch, env.ctx);
   } else {
     // * 2 does not help on Turin.
    FlashAttention(num_tokens,
                    /*target_parallelism=*/env.ctx.pools.MaxWorkers() * 1,
-                   layer_idx, layer, activations, qbatch, env.ctx);
+                   layer_idx, layer.query_norm_scale, activations, qbatch,
+                   env.ctx);
   }
   SumHeads(layer, activations, env);
 }

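The call sites above now pass `layer.query_norm_scale` instead of the whole `LayerWeightsPtrs`. A hedged, illustrative sketch of that narrowing; the Tensor/LayerWeights names here are hypothetical, not the gemma.cpp types:

    struct Tensor {
      const float* data = nullptr;
      bool HasPtr() const { return data != nullptr; }
    };

    struct LayerWeights {
      Tensor query_norm_scale;
      Tensor attn_out;  // ...many more tensors a caller would otherwise need
    };

    // Before: the callee takes the whole struct but reads one field.
    void ScaleQueryOld(const LayerWeights& layer) {
      if (layer.query_norm_scale.HasPtr()) { /* apply RMSNorm to q */ }
    }

    // After: the callee asks for exactly what it needs; a test can pass a
    // single tensor without assembling a full LayerWeights.
    void ScaleQueryNew(const Tensor& query_norm_scale) {
      if (query_norm_scale.HasPtr()) { /* apply RMSNorm to q */ }
    }

    int main() {
      LayerWeights layer;
      ScaleQueryOld(layer);                   // needs the whole struct
      ScaleQueryNew(layer.query_norm_scale);  // needs only the one tensor
      return 0;
    }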
@@ -29,8 +29,7 @@ namespace gcpp {
 #define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \
   namespace NAMESPACE { \
   void PositionalEncodingQK(float* qk, size_t layer_idx, \
-                            const LayerWeightsPtrs& layer, \
-                            const AttentionActivations& activations, \
+                            const AttentionActivationsPtrs& activations, \
                             ThreadingContext& ctx, size_t worker, size_t pos, \
                             float mul); \
 \

@@ -39,18 +38,18 @@ namespace gcpp {
   void SingleDotSoftmaxWeightedSum( \
       const size_t pos, const size_t start_pos, const size_t last_pos, \
       float* HWY_RESTRICT q, const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
-      size_t layer_idx, const LayerWeightsPtrs& layer, \
-      const AttentionActivations& activations, float* HWY_RESTRICT att, \
+      const MatPtrT<float>& query_norm_scale, size_t layer_idx, \
+      const AttentionActivationsPtrs& activations, float* HWY_RESTRICT att, \
       float* HWY_RESTRICT att_out, ThreadingContext& ctx, size_t worker); \
 \
   void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \
-                             const LayerWeightsPtrs& layer, \
-                             AttentionActivations& activations, \
+                             const MatPtrT<float>& query_norm_scale, \
+                             AttentionActivationsPtrs& activations, \
                              QBatch& qbatch, ThreadingContext& ctx); \
 \
   void GemmaAttention(size_t num_tokens, const size_t layer_idx, \
                       const LayerWeightsPtrs& layer, \
-                      AttentionActivations& activations, QBatch& qbatch, \
+                      AttentionActivationsPtrs& activations, QBatch& qbatch, \
                       MatMulEnv& env, int flags); \
   /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
   }  // namespace NAMESPACE

@@ -30,7 +30,6 @@
 #include "gemma/activations.h"
 #include "gemma/configs.h"  // kMaxQKVDim
 #include "gemma/gemma.h"
-#include "gemma/weights.h"
 #include "util/threading.h"
 #include "hwy/profiler.h"
 

@@ -91,32 +90,33 @@ static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
 
 // Updates q in place for RMSNorm and positional encoding.
 void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
-                                  MatPtrT<KV_t>& q, const size_t layer_idx,
-                                  const LayerWeightsPtrs& layer,
-                                  const AttentionActivations& activations,
+                                  MatPtrT<float>& q,
+                                  const MatPtrT<float>& query_norm_scale,
+                                  const size_t layer_idx,
+                                  const AttentionActivationsPtrs& activations,
                                   ThreadingContext& ctx) {
+  const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
   const float query_scale = activations.query_scale;
   const hwy::Divisor div_qbatch(qbatch.Size());
   const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
     GCPP_ZONE(ctx, worker, Zones::kFlashAttentionRmsNormAndPositionalEncoding);
     size_t qi = div_qbatch.Remainder(task);
     size_t batch_idx = div_qbatch.Divide(task);
-    for (size_t h = 0; h < layer.layer_config.heads; ++h) {
+    for (size_t h = 0; h < layer_config.heads; ++h) {
       const size_t tq_idx = qbatch.Size() * batch_idx + qi;
       // Find the token position in the query and calculate
       // the range of cache positions to attend to.
       const size_t pos = qbatch.Pos(qi) + batch_idx;
-      float* HWY_RESTRICT q_row =
-          q.Row(tq_idx) + h * layer.layer_config.qkv_dim;
+      float* HWY_RESTRICT q_row = q.Row(tq_idx) + h * layer_config.qkv_dim;
       // Apply rope and scaling to Q.
-      if (layer.query_norm_scale.HasPtr()) {
-        CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
+      if (query_norm_scale.HasPtr()) {
+        CallUpcasted(&query_norm_scale, [&](const auto* weights_t) {
          RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, q_row,
-                         layer.layer_config.qkv_dim, ctx, worker);
+                         layer_config.qkv_dim, ctx, worker);
        });
       }
-      PositionalEncodingQK(q_row, layer_idx, layer, activations, ctx, worker,
-                           pos, query_scale);
+      PositionalEncodingQK(q_row, layer_idx, activations, ctx, worker, pos,
+                           query_scale);
     }
   };
   {

@@ -154,8 +154,7 @@ void HWY_INLINE SingleFlashAttentionStep(float x, float cap, float& old_max,
 void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
                           const float* HWY_RESTRICT q, const MatPtrT<KV_t>& k,
                           const MatPtrT<KV_t>& v, const size_t layer_idx,
-                          const LayerWeightsPtrs& layer,
-                          const AttentionActivations& activations,
+                          const AttentionActivationsPtrs& activations,
                           float* HWY_RESTRICT att_out, ThreadingContext& ctx,
                           const size_t worker) {
   GCPP_ZONE(ctx, worker, Zones::kFlashAttentionSingleFlashAttention);

@@ -265,15 +264,17 @@ VF HWY_INLINE ElementwiseSumOf8(DF df, const VF& x0, const VF& x1, const VF& x2,
 // Sweeps a tile of NF Q rows by 8 K timesteps accumulators from start_pos to
 // min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
 // max_last_pos].
-void TileFlashAttention(
-    const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
-    const StridedView<float>& qT, const MatPtrT<KV_t>& k,
-    const size_t start_pos, const uint32_t* HWY_RESTRICT last_pos,
-    const size_t min_last_pos, const size_t max_last_pos,
-    const MatPtrT<KV_t>& v, const size_t layer_idx,
-    const LayerWeightsPtrs& layer, const AttentionActivations& activations,
-    MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets,
-    ThreadingContext& ctx, const size_t worker) {
+void TileFlashAttention(const MatPtrT<float>& q,
+                        const uint32_t* HWY_RESTRICT q_offsets,
+                        const StridedView<float>& qT, const MatPtrT<KV_t>& k,
+                        const size_t start_pos,
+                        const uint32_t* HWY_RESTRICT last_pos,
+                        const size_t min_last_pos, const size_t max_last_pos,
+                        const MatPtrT<KV_t>& v, const size_t layer_idx,
+                        const AttentionActivationsPtrs& activations,
+                        MatPtrT<float>& att_out,
+                        const uint32_t* HWY_RESTRICT out_offsets,
+                        ThreadingContext& ctx, const size_t worker) {
   GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention);
   constexpr int kHTileSize = kNFx8HTileSize;
   using DF = hn::ScalableTag<float>;

@@ -419,14 +420,16 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
 // Sweeps a tile of 4 Q rows by NF K timesteps accumulators from start_pos to
 // min_last_pos, then sweeps the remaining timesteps in the range (min_last_pos,
 // max_last_pos].
-void TileFlashAttention4(
-    const MatPtrT<float>& q, const uint32_t* HWY_RESTRICT q_offsets,
-    const MatPtrT<KV_t>& k, const size_t start_pos,
-    const uint32_t* HWY_RESTRICT last_pos, const size_t min_last_pos,
-    const size_t max_last_pos, const MatPtrT<KV_t>& v, const size_t layer_idx,
-    const LayerWeightsPtrs& layer, const AttentionActivations& activations,
-    MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets,
-    ThreadingContext& ctx, const size_t worker) {
+void TileFlashAttention4(const MatPtrT<float>& q,
+                         const uint32_t* HWY_RESTRICT q_offsets,
+                         const MatPtrT<KV_t>& k, const size_t start_pos,
+                         const uint32_t* HWY_RESTRICT last_pos,
+                         const size_t min_last_pos, const size_t max_last_pos,
+                         const MatPtrT<KV_t>& v, const size_t layer_idx,
+                         const AttentionActivationsPtrs& activations,
+                         MatPtrT<float>& att_out,
+                         const uint32_t* HWY_RESTRICT out_offsets,
+                         ThreadingContext& ctx, const size_t worker) {
   GCPP_ZONE(ctx, worker, Zones::kFlashAttentionTileFlashAttention4);
   using DF = hn::ScalableTag<float>;
   const DF df;

@@ -589,14 +592,15 @@ size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens,
 // grouped together so that mode 1 or 2 can be used, and choosing which of the
 // 3 modes to use for best efficiency.
 void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
-                    const size_t layer_idx, const LayerWeightsPtrs& layer,
-                    AttentionActivations& activations, QBatch& qbatch,
+                    const size_t layer_idx,
+                    const MatPtrT<float>& query_norm_scale,
+                    AttentionActivationsPtrs& activations, QBatch& qbatch,
                     ThreadingContext& ctx) {
   GCPP_ZONE(ctx, 0, Zones::kFlashAttentionInclusive);
-  RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q, layer_idx,
-                               layer, activations, ctx);
+  RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q,
+                               query_norm_scale, layer_idx, activations, ctx);
   const hwy::Divisor div_qbatch(qbatch.Size());
-  const LayerConfig& layer_config = layer.layer_config;
+  const LayerConfig& layer_config = activations.config.layer_configs[layer_idx];
   const size_t qkv_dim = layer_config.qkv_dim;
 
   // A "head group" in the context of GQA refers to a collection of query

@@ -732,12 +736,12 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
       // is used above to catch all cases where qT will be used.
       TileFlashAttention(activations.q, q_offsets, qT, k,
                          start_positions[offset], last_pos, min_last_pos,
-                         max_last_pos, v, layer_idx, layer, activations,
+                         max_last_pos, v, layer_idx, activations,
                          activations.att_out, out_offsets, ctx, worker);
     } else if (kVTileSize == 4) {
       TileFlashAttention4(activations.q, q_offsets, k,
                           start_positions[offset], last_pos, min_last_pos,
-                          max_last_pos, v, layer_idx, layer, activations,
+                          max_last_pos, v, layer_idx, activations,
                           activations.att_out, out_offsets, ctx, worker);
     } else {
       HWY_UNREACHABLE;

@@ -746,7 +750,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
     } else {
       SingleFlashAttention(start_positions[offset], last_pos[offset],
                            activations.q.Row(0) + q_offsets[offset], k, v,
-                           layer_idx, layer, activations,
+                           layer_idx, activations,
                            activations.att_out.Row(0) + out_offsets[offset],
                            ctx, worker);
     }

@@ -28,17 +28,16 @@ namespace gcpp {
 // Passed to HWY_VISIT_TARGETS; declares for one target.
 #define GEMMA_DECL_FLASH_ATTENTION(TARGET, NAMESPACE) \
   namespace NAMESPACE { \
-  void RMSNormAndPositionalEncoding(size_t num_tokens, const QBatch& qbatch, \
-                                    MatPtrT<KV_t>& q, size_t layer_idx, \
-                                    const LayerWeightsPtrs& layer, \
-                                    const AttentionActivations& activations, \
-                                    ThreadingContext& ctx); \
+  void RMSNormAndPositionalEncoding( \
+      size_t num_tokens, const QBatch& qbatch, MatPtrT<float>& q, \
+      const MatPtrT<float>& query_norm_scale, size_t layer_idx, \
+      const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \
 \
   void SingleFlashAttention(size_t start_pos, size_t last_pos, \
                             const float* HWY_RESTRICT q, \
                             const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
-                            size_t layer_idx, const LayerWeightsPtrs& layer, \
-                            const AttentionActivations& activations, \
+                            size_t layer_idx, \
+                            const AttentionActivationsPtrs& activations, \
                             float* HWY_RESTRICT att_out, \
                             ThreadingContext& ctx, size_t worker); \
 \

@@ -46,8 +45,9 @@ namespace gcpp {
                          size_t total_tasks, size_t target_parallelism); \
 \
   void FlashAttention(size_t num_tokens, size_t target_parallelism, \
-                      size_t layer_idx, const LayerWeightsPtrs& layer, \
-                      AttentionActivations& activations, QBatch& qbatch, \
+                      size_t layer_idx, \
+                      const MatPtrT<float>& query_norm_scale, \
+                      AttentionActivationsPtrs& activations, QBatch& qbatch, \
                       ThreadingContext& ctx); \
   /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
   }  // namespace NAMESPACE

@@ -122,8 +122,9 @@ void TestFlashAttention(size_t target_parallelism) {
   QBatch qbatch(/*start=*/0, /*max_size=*/kOuter, all_queries);
   const size_t batch_size = kOuter;
   std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
-  AttentionActivations attention(config, layer_config, batch_size, kOuter,
-                                 ctx.allocator, row_ptrs);
+  AttentionActivations attention_storage(config, layer_config, batch_size,
+                                         kOuter, ctx.allocator, row_ptrs);
+  AttentionActivationsPtrs attention(config, kOuter, attention_storage);
   const size_t qkv_dim = layer_config.qkv_dim;
   ASSERT_EQ(qkv_dim, kInner);
   const hwy::Divisor div_qbatch(qbatch.Size());

@@ -145,7 +146,8 @@ void TestFlashAttention(size_t target_parallelism) {
     SetMat(h + layer_config.heads * 2, v);
   }
   SetMat(1, attention.q);
-  DotSoftmaxWeightedSum(tokens.size(), 0, layers, attention, qbatch, ctx);
+  DotSoftmaxWeightedSum(tokens.size(), 0, layers.query_norm_scale, attention,
+                        qbatch, ctx);
   // Copy the output to saved_att to allow for comparison.
   auto saved_att = MakeCopyOfMat(attention.att_out, ctx.allocator);
   SetMat(1, attention.q);

@@ -158,8 +160,8 @@ void TestFlashAttention(size_t target_parallelism) {
                                         total_tasks, target_parallelism);
   printf("FlashAttention: target_parallelism=%zu, kNF=%zu, kVTileSize=%zu\n",
          target_parallelism, kNF, kVTileSize);
-  FlashAttention(tokens.size(), target_parallelism, 0, layers, attention,
-                 qbatch, ctx);
+  FlashAttention(tokens.size(), target_parallelism, 0, layers.query_norm_scale,
+                 attention, qbatch, ctx);
   AssertClose(attention.att_out, *saved_att);
   ctx.profiler.PrintResults();
 }

@@ -51,7 +51,7 @@ struct PerQuery {
   // attention in Paligemma.
   size_t prefix_end;
 
-  KVCache& kv_cache;
+  KVCachePtr kv_cache;
 
   // Previous token generated for this query, or the last prompt token. Will be
   // fed into the next Transformer() call.

@@ -64,7 +64,7 @@ struct AllQueries {
 
   // For `GenerateSingleT`: same prompt/pos, replicated for each KV cache.
   AllQueries(const PromptTokens& prompt, size_t pos, size_t prefix_end,
-             const hwy::Span<KVCache>& kv_caches) {
+             const hwy::Span<KVCachePtr>& kv_caches) {
     per_query_.reserve(kv_caches.size());
     for (size_t i = 0; i < kv_caches.size(); ++i) {
       HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen());

@@ -78,11 +78,16 @@ struct AllQueries {
     }
   }
 
+  AllQueries(const PromptTokens& prompt, size_t pos, size_t prefix_end,
+             const hwy::Span<KVCache>& kv_caches)
+      : AllQueries(prompt, pos, prefix_end,
+                   hwy::Span<KVCachePtr>(ToKVCachePtrs(kv_caches))) {}
+
   // Batch of queries with initial position set to zero. Causal attention
   // is requested via empty or all-zero `prefix_end`.
   AllQueries(
       const hwy::Span<const PromptTokens>& prompts,
-      const hwy::Span<KVCache>& kv_caches,
+      const hwy::Span<KVCachePtr>& kv_caches,
       const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>()) {
     HWY_ASSERT(prompts.size() == kv_caches.size());
     HWY_ASSERT(prompts.size() == prefix_end.size() || prefix_end.size() == 0);

@@ -99,6 +104,13 @@
     }
   }
 
+  AllQueries(
+      const hwy::Span<const PromptTokens>& prompts,
+      const hwy::Span<KVCache>& kv_caches,
+      const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>())
+      : AllQueries(prompts, hwy::Span<KVCachePtr>(ToKVCachePtrs(kv_caches)),
+                   prefix_end) {}
+
   void Reserve(size_t size) { per_query_.reserve(size); }
   void Append(const PerQuery& query) { per_query_.push_back(query); }
 

@@ -156,7 +168,7 @@ class QBatch {
   size_t PrefixEnd(size_t qi) const {
     return queries_[QueryIdx(qi)].prefix_end;
   }
-  KVCache& KV(size_t qi) const { return queries_[QueryIdx(qi)].kv_cache; }
+  KVCachePtr& KV(size_t qi) const { return queries_[QueryIdx(qi)].kv_cache; }
   int& PrevToken(size_t qi) { return queries_[QueryIdx(qi)].prev_token; }
 
  private:

@@ -16,6 +16,7 @@
 #include "gemma/kv_cache.h"
 
 #include <stddef.h>
+#include <vector>
 
 #include "gemma/configs.h"
 #include "gemma/gemma_args.h"

@@ -54,4 +55,13 @@ KVCache KVCache::Copy() {
   return copy;
 }
 
+std::vector<KVCachePtr> ToKVCachePtrs(const hwy::Span<KVCache>& kv_caches) {
+  std::vector<KVCachePtr> ptrs;
+  ptrs.reserve(kv_caches.size());
+  for (size_t i = 0; i < kv_caches.size(); ++i) {
+    ptrs.push_back(KVCachePtr{.kv_cache = kv_caches[i].kv_cache});
+  }
+  return ptrs;
+}
+
 }  // namespace gcpp

@@ -17,6 +17,7 @@
 #define THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_
 
 #include <stddef.h>
+#include <vector>
 
 #include "gemma/configs.h"     // ModelConfig
 #include "gemma/gemma_args.h"  // InferenceArgs

@@ -46,6 +47,15 @@ struct KVCache {
   KVCache(const Extents2D& kv_extents, const Allocator& allocator);
 };
 
+// A non-owning view of a KVCache.
+struct KVCachePtr {
+  size_t SeqLen() const { return kv_cache.Rows(); }
+  MatPtrT<KV_t> kv_cache;
+};
+
+// Convenience function to create views into KVCaches.
+std::vector<KVCachePtr> ToKVCachePtrs(const hwy::Span<KVCache>& kv_caches);
+
 }  // namespace gcpp
 
 #endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_KV_CACHE_H_

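A rough, self-contained sketch of what the KVCachePtr view and the ToKVCachePtrs helper amount to; the types below are simplified stand-ins rather than the real KVCache/MatPtrT, and buffer layout details are omitted:

    #include <cstddef>
    #include <vector>

    struct KVCache {  // owns the cache buffer
      KVCache(size_t rows, size_t cols) : rows(rows), buf(rows * cols) {}
      size_t rows;
      std::vector<float> buf;
    };

    struct KVCachePtr {  // non-owning view, cheap to copy into each PerQuery
      size_t SeqLen() const { return rows; }
      float* data = nullptr;
      size_t rows = 0;
    };

    std::vector<KVCachePtr> ToKVCachePtrs(std::vector<KVCache>& caches) {
      std::vector<KVCachePtr> ptrs;
      ptrs.reserve(caches.size());
      for (KVCache& c : caches) {
        ptrs.push_back(KVCachePtr{c.buf.data(), c.rows});
      }
      return ptrs;
    }

    int main() {
      std::vector<KVCache> caches;
      caches.emplace_back(/*rows=*/1024, /*cols=*/256);
      std::vector<KVCachePtr> views = ToKVCachePtrs(caches);
      return views[0].SeqLen() == 1024 ? 0 : 1;
    }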
@@ -454,7 +454,7 @@ void TestRopeAndMulBy() {
     x.Row(0)[i] = random_float();
   }
 
-  const float qmul = AttentionActivations::ChooseQueryScale(config);
+  const float qmul = ChooseQueryScale(config);
   constexpr float kmul = 1.0f;
 
   MatStorageT<float> qexpected("qexpected", dim_qkv, ctx.allocator);

@@ -284,6 +284,9 @@ class MatPtrT : public MatPtr {
  public:
   using T = MatT;
 
+  // Default constructor for use with uninitialized views.
+  MatPtrT() = default;
+
   // Called by `MatStorageT`.
   MatPtrT(const char* name, Extents2D extents)
       : MatPtr(name, TypeEnum<MatT>(), extents) {}

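Why the defaulted MatPtrT constructor matters: view members can now be declared empty and bound later, as AttentionActivationsPtrs does when its delegating constructor assigns the views in its body. A tiny sketch with a hypothetical View type, not the real MatPtrT:

    #include <cstddef>

    struct View {
      View() = default;  // starts unbound
      float* data = nullptr;
      size_t rows = 0, cols = 0;
      bool HasPtr() const { return data != nullptr; }
    };

    struct Holder {
      Holder() = default;  // members default-construct to empty views
      void Bind(float* d, size_t r, size_t c) {
        // Rebind later; no allocation or data copy, only metadata.
        view.data = d;
        view.rows = r;
        view.cols = c;
      }
      View view;
    };

    int main() {
      float buf[16] = {};
      Holder h;           // h.view starts empty (HasPtr() == false)
      h.Bind(buf, 4, 4);  // bound after construction
      return h.view.HasPtr() ? 0 : 1;
    }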