Split Activations into Griffin/Attention to reduce memory usage for attention-only tests.

PiperOrigin-RevId: 772025282
Jan Wassenberg 2025-06-16 07:52:27 -07:00 committed by Copybara-Service
parent 2128d076db
commit 6773e4517c
8 changed files with 174 additions and 141 deletions
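In outline, the split moves every buffer the attention kernels touch into a new `AttentionActivations` struct, and the Griffin-specific buffers into a `GriffinActivations` struct that `Activations` only allocates for Griffin models. The following condensed, self-contained sketch shows the resulting ownership; the `Mat` type and the member lists are simplified stand-ins, not the real `MatStorageT` members and constructor signatures shown in the diff below.

  #include <memory>
  #include <vector>

  // Stand-in for MatStorageT<T>; only the ownership pattern matters here.
  struct Mat { std::vector<float> data; };

  struct GriffinActivations {    // allocated only for Griffin models
    Mat griffin_x, griffin_y, griffin_gate_x, griffin_multiplier;
  };

  struct AttentionActivations {  // everything the attention code touches
    Mat q, pre_att_rms_out, att, att_out, att_sums;
  };

  struct Activations {           // full per-model activations
    Mat x, logits, pre_ffw_rms_out, C1, C2, ffw_out;
    AttentionActivations attention;
    std::unique_ptr<GriffinActivations> griffin;  // nullptr for non-Griffin
  };

  int main() {
    Activations a;                  // attention buffers are always present
    const bool is_griffin = false;  // e.g. config.model == Model::GRIFFIN_2B
    if (is_griffin) a.griffin = std::make_unique<GriffinActivations>();
    return 0;
  }

Because the Griffin buffers sit behind a unique_ptr, non-Griffin models no longer allocate those four model_dim-wide matrices, and attention-only tests can construct just an `AttentionActivations` instead of a full `Activations`.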

View File

@@ -360,6 +360,7 @@ cc_test(
deps = [
":allocator",
":basics",
":configs",
":gemma_lib",
":mat",
":ops",

View File

@@ -31,25 +31,46 @@
namespace gcpp {
// Returns the scale value to use for the query in the attention computation.
// Also called by ops_test.
static inline float ChooseQueryScale(const ModelConfig& config) {
struct GriffinActivations {
GriffinActivations(const ModelConfig& config, size_t batch_size,
MatPadding pad)
: griffin_x("griffin_x", Extents2D(batch_size, config.model_dim), pad),
griffin_y("griffin_y", Extents2D(batch_size, config.model_dim), pad),
griffin_gate_x("griffin_gate_x",
Extents2D(batch_size, config.model_dim), pad),
griffin_multiplier("griffin_mul",
Extents2D(batch_size, config.model_dim), pad) {}
void SetBatchSize(size_t batch_size) {
griffin_x.OverrideRows(batch_size);
griffin_y.OverrideRows(batch_size);
griffin_gate_x.OverrideRows(batch_size);
griffin_multiplier.OverrideRows(batch_size);
}
MatStorageT<float> griffin_x;
MatStorageT<float> griffin_y;
MatStorageT<float> griffin_gate_x;
MatStorageT<float> griffin_multiplier;
};
struct AttentionActivations {
// Returns the scale value to use for the query in the attention computation.
// Also called by ops_test.
static inline float ChooseQueryScale(const ModelConfig& config) {
if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads)
return 1.0f / sqrtf(static_cast<float>(config.model_dim /
config.layer_configs[0].heads));
// QueryScaleType::SqrtKeySize
return 1.0f / sqrtf(static_cast<float>(config.layer_configs[0].qkv_dim));
}
}
struct Activations {
Activations(const ModelConfig& config, size_t batch_size, size_t seq_len,
AttentionActivations(
const ModelConfig& config, const LayerConfig& layer_config,
size_t batch_size, size_t seq_len, MatPadding pad,
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
: weights_config(config),
layer_config(config.layer_configs[0]),
div_seq_len(static_cast<uint32_t>(seq_len)),
is_griffin(config.model == Model::GRIFFIN_2B),
: config(config),
x("x", Extents2D(batch_size, config.model_dim), pad_),
// `vocab_size == 0` means this is the ViT part; VitAttention is still MHA
// and does not use an external KV cache.
q("q",
@@ -57,36 +78,16 @@ struct Activations {
config.vocab_size == 0
? layer_config.heads * 3 * layer_config.qkv_dim
: layer_config.heads * layer_config.qkv_dim),
pad_),
logits("logits", Extents2D(batch_size, config.vocab_size), pad_),
pad),
pre_att_rms_out("pre_att_rms_out",
Extents2D(batch_size, config.model_dim), pad_),
att("att", Extents2D(batch_size, layer_config.heads * seq_len), pad_),
Extents2D(batch_size, config.model_dim), pad),
att("att", Extents2D(batch_size, layer_config.heads * seq_len), pad),
att_out(
"att_out",
Extents2D(batch_size, layer_config.heads * layer_config.qkv_dim),
pad_),
att_sums("att_sums", Extents2D(batch_size, config.model_dim), pad_),
pre_ffw_rms_out("pre_ffw_rms_out",
Extents2D(batch_size, config.model_dim), pad_),
C1("C1", Extents2D(batch_size, layer_config.ff_hidden_dim), pad_),
C2("C2", Extents2D(batch_size, layer_config.ff_hidden_dim), pad_),
ffw_out("ffw_out", Extents2D(batch_size, config.model_dim), pad_),
griffin_x("griffin_x",
is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
pad_),
griffin_y("griffin_y",
is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
pad_),
griffin_gate_x(
"griffin_gate_x",
is_griffin ? Extents2D(batch_size, config.model_dim) : none_, pad_),
griffin_multiplier(
"griffin_mul",
is_griffin ? Extents2D(batch_size, config.model_dim) : none_, pad_),
pad),
att_sums("att_sums", Extents2D(batch_size, config.model_dim), pad),
inv_timescale(
CreateInvTimescale(layer_config.qkv_dim,
@@ -95,16 +96,73 @@ struct Activations {
layer_config.qkv_dim, layer_config.post_qk == PostQKType::HalfRope,
1000000.0)),
div_seq_len(static_cast<uint32_t>(seq_len)),
query_scale(ChooseQueryScale(config)) {
HWY_ASSERT(batch_size != 0);
// For MatMul outputs, precompute their row pointers.
// If we forget any MatMul outputs here, debug builds print a warning, but
// they are still filled in by each MatMul call.
x.AllocateAndAttachRowPtrs(row_ptrs);
q.AllocateAndAttachRowPtrs(row_ptrs);
logits.AllocateAndAttachRowPtrs(row_ptrs);
att_sums.AllocateAndAttachRowPtrs(row_ptrs);
}
void SetBatchSize(size_t batch_size) {
q.OverrideRows(batch_size);
pre_att_rms_out.OverrideRows(batch_size);
att.OverrideRows(batch_size);
att_out.OverrideRows(batch_size);
att_sums.OverrideRows(batch_size);
}
bool IsGlobalLayer(size_t layer_idx) const {
return config.attention_window_sizes[layer_idx] == div_seq_len.GetDivisor();
}
const ModelConfig& config;
MatStorageT<float> q; // query
MatStorageT<float> pre_att_rms_out;
MatStorageT<float> att; // attention vector
MatStorageT<float> att_out; // attention output
// Accumulation of attention outputs over heads
MatStorageT<BF16> att_sums;
// Rope
MatStorageT<float> inv_timescale;
MatStorageT<float> inv_timescale_global;
hwy::Divisor div_seq_len;
float query_scale;
};
struct Activations {
Activations(const ModelConfig& config, size_t batch_size, size_t seq_len,
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
: layer_config(config.layer_configs[0]),
x("x", Extents2D(batch_size, config.model_dim), pad_),
logits("logits", Extents2D(batch_size, config.vocab_size), pad_),
pre_ffw_rms_out("pre_ffw_rms_out",
Extents2D(batch_size, config.model_dim), pad_),
C1("C1", Extents2D(batch_size, layer_config.ff_hidden_dim), pad_),
C2("C2", Extents2D(batch_size, layer_config.ff_hidden_dim), pad_),
ffw_out("ffw_out", Extents2D(batch_size, config.model_dim), pad_),
attention(config, layer_config, batch_size, seq_len, pad_, row_ptrs) {
HWY_ASSERT(batch_size != 0);
if (config.model == Model::GRIFFIN_2B) {
griffin = std::make_unique<GriffinActivations>(config, batch_size, pad_);
}
// For MatMul outputs, precompute their row pointers.
// If we forget any MatMul outputs here, debug builds print a warning, but
// they are still filled in by each MatMul call.
x.AllocateAndAttachRowPtrs(row_ptrs);
logits.AllocateAndAttachRowPtrs(row_ptrs);
C1.AllocateAndAttachRowPtrs(row_ptrs);
C2.AllocateAndAttachRowPtrs(row_ptrs);
ffw_out.AllocateAndAttachRowPtrs(row_ptrs);
@@ -115,67 +173,35 @@ struct Activations {
void SetBatchSize(size_t batch_size) {
PROFILER_ZONE("SetBatchSize");
x.OverrideRows(batch_size);
q.OverrideRows(batch_size);
logits.OverrideRows(batch_size);
pre_att_rms_out.OverrideRows(batch_size);
att.OverrideRows(batch_size);
att_out.OverrideRows(batch_size);
att_sums.OverrideRows(batch_size);
pre_ffw_rms_out.OverrideRows(batch_size);
C1.OverrideRows(batch_size);
C2.OverrideRows(batch_size);
ffw_out.OverrideRows(batch_size);
if (is_griffin) {
griffin_x.OverrideRows(batch_size);
griffin_y.OverrideRows(batch_size);
griffin_gate_x.OverrideRows(batch_size);
griffin_multiplier.OverrideRows(batch_size);
attention.SetBatchSize(batch_size);
if (griffin) {
griffin->SetBatchSize(batch_size);
}
}
bool IsGlobalLayer(size_t layer_idx) const {
return weights_config.attention_window_sizes[layer_idx] ==
div_seq_len.GetDivisor();
}
const ModelConfig& weights_config;
const LayerConfig& layer_config;
hwy::Divisor div_seq_len;
bool is_griffin;
const Extents2D none_ = Extents2D();
const MatPadding pad_ = MatPadding::kOdd;
MatStorageT<float> x; // input
MatStorageT<float> q; // query
MatStorageT<float> logits;
// Attention
MatStorageT<float> pre_att_rms_out;
MatStorageT<float> att; // attention vector
MatStorageT<float> att_out; // attention output
// Accumulation of attention outputs over heads
MatStorageT<BF16> att_sums;
// Gated FFW
MatStorageT<BF16> pre_ffw_rms_out;
MatStorageT<float> C1; // TODO: BF16 after Activation() supports it
MatStorageT<float> C2;
MatStorageT<BF16> ffw_out;
// Griffin
MatStorageT<float> griffin_x;
MatStorageT<float> griffin_y;
MatStorageT<float> griffin_gate_x;
MatStorageT<float> griffin_multiplier;
// Rope
MatStorageT<float> inv_timescale;
MatStorageT<float> inv_timescale_global;
float query_scale;
AttentionActivations attention;
std::unique_ptr<GriffinActivations> griffin;
};
} // namespace gcpp
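The payoff for the attention-only tests named in the commit message: they can construct an `AttentionActivations` directly and never allocate the FFW, logits, or Griffin buffers owned by `Activations`. A hypothetical fragment (not part of this diff), assuming `config` is obtained the same way existing tests obtain it:

  // Allocate only the attention buffers.
  std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
  AttentionActivations attention(config, config.layer_configs[0],
                                 /*batch_size=*/4, /*seq_len=*/1024,
                                 MatPadding::kOdd, row_ptrs);
  attention.SetBatchSize(2);  // use fewer rows than were allocated
  // The per-target attention entry points below now take this directly, e.g.
  // GemmaAttention(num_tokens, layer_idx, layer, attention, qbatch, env, flags).

The batch size and sequence length above are placeholder values; since `SetBatchSize` only overrides the row count of already-allocated storage, it presumably must not exceed the batch size passed to the constructor.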

View File

@@ -20,7 +20,6 @@
#include "gemma/activations.h"
#include "gemma/gemma.h"
#include "gemma/gemma_args.h"
#include "gemma/weights.h"
#include "util/threading.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
@@ -66,14 +65,14 @@ template <typename U>
static void PositionalEncodingQK(U* qk, const size_t qkv_dim,
const size_t layer_idx,
const LayerWeightsPtrs& layer,
const Activations& activations,
const AttentionActivations& activations,
const size_t pos, const float mul = 1.0f) {
const PostQKType& post_qk = layer.layer_config.post_qk;
// qk is either q or k, so qkv_dim is the length we operate on.
const float* inv_timescale = activations.inv_timescale.PackedScale1();
bool is_global_layer = activations.IsGlobalLayer(layer_idx);
// TODO: add a config flag instead of hardcoding the model.
if (is_global_layer && IsVLM(activations.weights_config.model)) {
if (is_global_layer && IsVLM(activations.config.model)) {
inv_timescale = activations.inv_timescale_global.PackedScale1();
}
// PostQKType::Rope
@@ -118,10 +117,10 @@ void SingleDotSoftmaxWeightedSum(
const size_t pos, const size_t start_pos, const size_t last_pos,
float* HWY_RESTRICT q, const MatPtrT<float>& k, const MatPtrT<float>& v,
const size_t layer_idx, const LayerWeightsPtrs& layer,
const Activations& activations, float* HWY_RESTRICT att,
const AttentionActivations& activations, float* HWY_RESTRICT att,
float* HWY_RESTRICT att_out) {
const size_t qkv_dim = layer.layer_config.qkv_dim;
const float att_cap = activations.weights_config.att_cap;
const float att_cap = activations.config.att_cap;
const float query_scale = activations.query_scale;
const size_t seq_len =
static_cast<size_t>(activations.div_seq_len.GetDivisor());
@@ -155,7 +154,7 @@ static HWY_INLINE size_t StartPos(size_t pos, const ModelConfig& config,
void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
const LayerWeightsPtrs& layer,
Activations& activations, QBatch& qbatch,
AttentionActivations& activations, QBatch& qbatch,
NestedPools& pools) {
PROFILER_ZONE("Gen.Attention.DotSoftmax");
const hwy::Divisor div_qbatch(qbatch.Size());
@@ -190,7 +189,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
// the range of cache positions to attend to.
const size_t pos = qbatch.Pos(qi) + batch_idx;
const size_t start_pos =
StartPos(pos, activations.weights_config, layer_idx);
StartPos(pos, activations.config, layer_idx);
size_t last_pos = pos;
const size_t prefix_end = qbatch.PrefixEnd(qi);
if (prefix_end > 0 && prefix_end - 1 > last_pos) {
@@ -241,7 +240,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
// Fills activations.q and writes to KV cache.
static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
const LayerWeightsPtrs& layer,
Activations& activations,
AttentionActivations& activations,
const QBatch& qbatch, const int flags,
MatMulEnv& env) {
PROFILER_ZONE("Gen.Attention.QKV");
@@ -306,7 +305,8 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
// Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and
// head_dim (`qkv_dim`) into output (`layer_out`).
static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
Activations& activations, MatMulEnv& env) {
AttentionActivations& activations,
MatMulEnv& env) {
PROFILER_ZONE("Gen.Attention.SumHeads");
const LayerConfig& layer_config = layer.layer_config;
// att_weights and att_out are concatenated heads, each of length
@@ -324,8 +324,9 @@ static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
}
void GemmaAttention(size_t num_tokens, const size_t layer_idx,
const LayerWeightsPtrs& layer, Activations& activations,
QBatch& qbatch, MatMulEnv& env, int flags) {
const LayerWeightsPtrs& layer,
AttentionActivations& activations, QBatch& qbatch,
MatMulEnv& env, int flags) {
const LayerConfig& layer_config = layer.layer_config;
HWY_DASSERT(!layer_config.IsMHA()); // No longer supported.
HWY_DASSERT_M((layer_config.heads % layer_config.kv_heads) == 0,

View File

@@ -32,17 +32,18 @@ namespace gcpp {
const size_t pos, const size_t start_pos, const size_t last_pos, \
float* HWY_RESTRICT q, const MatPtrT<float>& k, const MatPtrT<float>& v, \
size_t layer_idx, const LayerWeightsPtrs& layer, \
const Activations& activations, float* HWY_RESTRICT att, \
const AttentionActivations& activations, float* HWY_RESTRICT att, \
float* HWY_RESTRICT att_out); \
\
void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \
const LayerWeightsPtrs& layer, \
Activations& activations, QBatch& qbatch, \
NestedPools& pools); \
AttentionActivations& activations, \
QBatch& qbatch, NestedPools& pools); \
\
void GemmaAttention(size_t num_tokens, const size_t layer_idx, \
const LayerWeightsPtrs& layer, Activations& activations, \
QBatch& qbatch, MatMulEnv& env, int flags); \
const LayerWeightsPtrs& layer, \
AttentionActivations& activations, QBatch& qbatch, \
MatMulEnv& env, int flags); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE

View File

@@ -65,14 +65,15 @@ void Attention(LayerAttentionType type, const size_t num_tokens,
const size_t layer_idx, const LayerWeightsPtrs& layer,
Activations& activations, QBatch& qbatch, MatMulEnv& env) {
if (type == LayerAttentionType::kGemma) {
GemmaAttention(num_tokens, layer_idx, layer, activations, qbatch, env,
GemmaAttention(num_tokens, layer_idx, layer, activations.attention, qbatch,
env,
/*flags=*/0);
} else {
HWY_DASSERT(type == LayerAttentionType::kGriffinRecurrentBlock);
// KVCache conv1d_cache and rglru_cache have one row per *Griffin* layer,
// so map `layer` to the Griffin layer index.
const size_t griffin_layer =
activations.weights_config.NumLayersOfTypeBefore(type, layer_idx);
activations.attention.config.NumLayersOfTypeBefore(type, layer_idx);
GriffinRecurrent(num_tokens, griffin_layer, &layer, activations, qbatch,
env);
}
@@ -86,15 +87,15 @@ static HWY_NOINLINE void TransformerLayer(const size_t num_tokens,
const LayerConfig& layer_config = layer.layer_config;
RMSNormBatched(activations.x, layer.pre_attention_norm_scale,
activations.pre_att_rms_out);
activations.attention.pre_att_rms_out);
Attention(layer_config.type, num_tokens, layer_idx, layer, activations,
qbatch, env);
PostNorm(layer_config.post_norm, layer.post_attention_norm_scale,
activations.att_sums);
activations.attention.att_sums);
ResidualConnection(activations.att_sums, activations.x, layer,
ResidualConnection(activations.attention.att_sums, activations.x, layer,
/*is_attention=*/true);
RMSNormBatched(activations.x, layer.pre_ffw_norm_scale,
@@ -470,7 +471,7 @@ static void GenerateT(const ModelConfig& config,
HWY_ASSERT(qbatch.KV(qi).SeqLen() == seq_len);
}
HWY_ASSERT(prefill_tokens < seq_len);
activations.div_seq_len = hwy::Divisor(static_cast<uint32_t>(seq_len));
HWY_ASSERT(activations.attention.div_seq_len.GetDivisor() == seq_len);
// Lacks a constructor to bulk-set, hence initialized by Prefill* which have
// qi loops anyway.

View File

@@ -60,20 +60,21 @@ void GriffinRecurrent(size_t num_tokens, size_t griffin_layer,
const size_t num_interleaved = num_tokens * qbatch.Size();
const hwy::Divisor div_qbatch(static_cast<uint32_t>(qbatch.Size()));
GriffinActivations& griffin = *activations.griffin;
// X / Y linear layers.
// TODO: MatMul
HWY_DASSERT(activations.griffin_y.Rows() == activations.griffin_x.Rows());
HWY_DASSERT(num_interleaved == activations.griffin_y.Rows());
HWY_DASSERT(griffin.griffin_y.Rows() == griffin.griffin_x.Rows());
HWY_DASSERT(num_interleaved == griffin.griffin_y.Rows());
CallUpcastedSame(
&layer_weights->griffin.linear_x_w, &layer_weights->griffin.linear_y_w,
[&](const auto* wx, const auto* wy) {
for (size_t r = 0; r < num_interleaved; ++r) {
float* HWY_RESTRICT y = activations.griffin_y.Row(r);
float* HWY_RESTRICT x = activations.griffin_x.Row(r);
float* HWY_RESTRICT y = griffin.griffin_y.Row(r);
float* HWY_RESTRICT x = griffin.griffin_x.Row(r);
TwoMatVecAdd(
*wx, *wy, 0, model_dim, model_dim,
activations.pre_att_rms_out.Row(r),
activations.attention.pre_att_rms_out.Row(r),
/*add0=*/layer_weights->griffin.linear_x_biases.PackedScale1(),
/*add1=*/layer_weights->griffin.linear_y_biases.PackedScale1(),
/*out0=*/x, /*out1=*/y, pool);
@@ -87,7 +88,7 @@ void GriffinRecurrent(size_t num_tokens, size_t griffin_layer,
const size_t qi = div_qbatch.Remainder(interleaved_idx);
const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
const size_t pos = qbatch.Pos(qi) + batch_idx;
float* HWY_RESTRICT x = activations.griffin_x.Row(qi);
float* HWY_RESTRICT x = griffin.griffin_x.Row(qi);
// cache[i] = input at time t-i.
float* HWY_RESTRICT cache[kMaxConv1DWidth];
@@ -124,10 +125,10 @@ void GriffinRecurrent(size_t num_tokens, size_t griffin_layer,
const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
const size_t pos = qbatch.Pos(qi) + batch_idx;
float* HWY_RESTRICT x = activations.griffin_x.Row(qi);
float* HWY_RESTRICT y = activations.griffin_y.Row(qi);
float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Row(qi);
float* HWY_RESTRICT a = activations.griffin_multiplier.Row(qi);
float* HWY_RESTRICT x = griffin.griffin_x.Row(qi);
float* HWY_RESTRICT y = griffin.griffin_y.Row(qi);
float* HWY_RESTRICT gate_x = griffin.griffin_gate_x.Row(qi);
float* HWY_RESTRICT a = griffin.griffin_multiplier.Row(qi);
float* HWY_RESTRICT rnn_state =
qbatch.KV(qi).rglru_cache.Row(griffin_layer);
@@ -175,9 +176,9 @@ void GriffinRecurrent(size_t num_tokens, size_t griffin_layer,
} // interleaved_idx
// Final linear layer.
CallMatMul(activations.griffin_x, layer_weights->griffin.linear_out_w,
CallMatMul(griffin.griffin_x, layer_weights->griffin.linear_out_w,
layer_weights->griffin.linear_out_biases.PackedScale1(), env,
activations.att_sums);
activations.attention.att_sums);
} // GriffinRecurrent
// NOLINTNEXTLINE(google-readability-namespace-comments)

View File

@@ -56,10 +56,10 @@ class VitAttention {
// Computes Q, K, V for all heads, stored in activations_.q.
HWY_NOINLINE void ComputeQKV() {
PROFILER_ZONE("Gen.VitAttention.QKV");
auto& qkv = activations_.q;
auto& qkv = activations_.attention.q;
HWY_ASSERT(qkv.Rows() == num_tokens_);
HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim);
CallMatMul(activations_.pre_att_rms_out, layer_.vit.qkv_einsum_w,
CallMatMul(activations_.attention.pre_att_rms_out, layer_.vit.qkv_einsum_w,
layer_.vit.qkv_einsum_b.PackedScale1(), env_, qkv);
}
@@ -69,7 +69,7 @@ class VitAttention {
const size_t heads = layer_config_.heads;
HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
const size_t seq_len =
static_cast<size_t>(activations_.div_seq_len.GetDivisor());
static_cast<size_t>(activations_.attention.div_seq_len.GetDivisor());
const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
@@ -82,12 +82,13 @@ class VitAttention {
MatPadding::kPacked);
// Initialize att_out to zero prior to head loop.
ZeroInit(activations_.att_out);
ZeroInit(activations_.attention.att_out);
for (size_t head = 0; head < heads; ++head) {
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
const size_t token = task;
float* HWY_RESTRICT q = activations_.q.Row(token) + head * 3 * qkv_dim;
float* HWY_RESTRICT q =
activations_.attention.q.Row(token) + head * 3 * qkv_dim;
// TODO: shift to MatMul with A.scale once MatMul is confirmed working
MulByConst(query_scale, q, qkv_dim);
hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float));
@@ -95,8 +96,8 @@ class VitAttention {
pool_.Run(0, seq_len, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
const size_t seq_idx = task;
float* HWY_RESTRICT k =
activations_.q.Row(seq_idx) + head * 3 * qkv_dim + qkv_dim;
float* HWY_RESTRICT k = activations_.attention.q.Row(seq_idx) +
head * 3 * qkv_dim + qkv_dim;
hwy::CopyBytes(k, K.Row(seq_idx), qkv_dim * sizeof(float));
});
@@ -111,10 +112,10 @@ class VitAttention {
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
size_t token = task;
float* HWY_RESTRICT att_out =
activations_.att_out.Row(token) + head * qkv_dim;
activations_.attention.att_out.Row(token) + head * qkv_dim;
for (size_t i = 0; i < seq_len; ++i) {
float* HWY_RESTRICT v =
activations_.q.Row(i) + head * 3 * qkv_dim + 2 * qkv_dim;
float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
head * 3 * qkv_dim + 2 * qkv_dim;
MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim);
}
});
@@ -126,7 +127,7 @@ class VitAttention {
const size_t heads = layer_config_.heads;
HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
const size_t seq_len =
static_cast<size_t>(activations_.div_seq_len.GetDivisor());
static_cast<size_t>(activations_.attention.div_seq_len.GetDivisor());
const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
@@ -137,24 +138,24 @@ class VitAttention {
const size_t token = task / layer_config_.heads;
// Compute Q.K scores, which are "logits" stored in head_att.
float* HWY_RESTRICT q =
activations_.q.Row(token) + head * 3 * qkv_dim;
activations_.attention.q.Row(token) + head * 3 * qkv_dim;
MulByConst(query_scale, q, qkv_dim);
float* HWY_RESTRICT head_att =
activations_.att.Row(token) + head * seq_len;
activations_.attention.att.Row(token) + head * seq_len;
for (size_t i = 0; i < seq_len; ++i) {
float* HWY_RESTRICT k =
activations_.q.Row(i) + head * 3 * qkv_dim + qkv_dim;
float* HWY_RESTRICT k = activations_.attention.q.Row(i) +
head * 3 * qkv_dim + qkv_dim;
head_att[i] = Dot(q, k, qkv_dim); // score = q.k
}
// SoftMax yields "probabilities" in head_att.
Softmax(head_att, seq_len);
// Compute weighted sum of v into att_out.
float* HWY_RESTRICT att_out =
activations_.att_out.Row(token) + head * qkv_dim;
activations_.attention.att_out.Row(token) + head * qkv_dim;
hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
for (size_t i = 0; i < seq_len; ++i) {
float* HWY_RESTRICT v =
activations_.q.Row(i) + head * 3 * qkv_dim + 2 * qkv_dim;
float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
head * 3 * qkv_dim + 2 * qkv_dim;
MulByConstAndAdd(head_att[i], v, att_out, qkv_dim);
}
});
@@ -168,8 +169,8 @@ class VitAttention {
// att_weights and att_out are concatenated heads, each of length
// qkv_dim. Thus the [num_tokens_, layer_config_.model_dim]
// matmul output is the sum over heads.
CallMatMul(activations_.att_out, layer_.vit.attn_out_w, bias, env_,
activations_.att_sums);
CallMatMul(activations_.attention.att_out, layer_.vit.attn_out_w, bias,
env_, activations_.attention.att_sums);
}
public:
@@ -184,7 +185,7 @@ class VitAttention {
HWY_INLINE void operator()() {
ComputeQKV();
if (activations_.weights_config.wrapping == PromptWrapping::GEMMA_VLM) {
if (activations_.attention.config.wrapping == PromptWrapping::GEMMA_VLM) {
DotSoftmaxWeightedSumMatrix();
} else {
DotSoftmaxWeightedSum();
@@ -233,7 +234,7 @@ void FFWVit(const LayerWeightsPtrs& layer, Activations& activations,
void VitTransformerLayer(size_t num_tokens, const size_t layer_idx,
const LayerWeightsPtrs& layer,
Activations& activations, MatMulEnv& env) {
const size_t model_dim = activations.weights_config.model_dim;
const size_t model_dim = activations.attention.config.model_dim;
auto type = layer.layer_config.type;
HWY_DASSERT(type == LayerAttentionType::kVit);
(void)type;
@@ -246,14 +247,14 @@ void VitTransformerLayer(size_t num_tokens, const size_t layer_idx,
// y = nn.LayerNorm()(x)
// y ~ pre_att_rms_out
LayerNormBatched(x, layer.vit.layer_norm_0_scale, layer.vit.layer_norm_0_bias,
activations.pre_att_rms_out);
activations.attention.pre_att_rms_out);
// y = out["sa"] = nn.MultiHeadDotProductAttention(...)(y, y)
// y ~ att_sums
VitAttention(num_tokens, layer_idx, activations, layer, env)();
// x = out["+sa"] = x + y
AddFromBatched(activations.att_sums, x);
AddFromBatched(activations.attention.att_sums, x);
// y = nn.LayerNorm()(x)
// y ~ pre_ffw_rms_out

View File

@@ -32,6 +32,7 @@
#include <vector>
#include "gemma/activations.h" // ChooseQueryScale
#include "gemma/configs.h"
#include "util/allocator.h"
#include "util/basics.h" // BF16
#include "util/mat.h" // MatStorageT
@@ -400,7 +401,7 @@ void TestRopeAndMulBy() {
x.Row(0)[i] = random_float();
}
const float qmul = ChooseQueryScale(config);
const float qmul = AttentionActivations::ChooseQueryScale(config);
const float kmul = 1.0;
MatStorageT<float> qexpected("qexpected", dim_qkv);