Split Activations into Griffin/Attention to reduce memory usage for attention-only tests.

PiperOrigin-RevId: 772025282
Jan Wassenberg 2025-06-16 07:52:27 -07:00 committed by Copybara-Service
parent 2128d076db
commit 6773e4517c
8 changed files with 174 additions and 141 deletions
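
Illustrative sketch (not part of this commit): after the split, an attention-only test can construct just `AttentionActivations` and skip the FFW and Griffin buffers, while `Activations` allocates `GriffinActivations` only for Griffin models. The two helper functions below are hypothetical and only show intended usage; the constructor signatures follow the new headers in the diff below, and the surrounding build setup is assumed.

    // Sketch only: assumes the gcpp headers from this repo are available.
    // AllocateForAttentionTest / AllocateForGeneration are hypothetical helpers.
    #include <cstddef>
    #include <vector>

    #include "gemma/activations.h"
    #include "gemma/configs.h"
    #include "util/mat.h"

    namespace gcpp {

    // An attention-only test now allocates only the attention buffers
    // (q, att, att_out, att_sums, ...), not the FFW or Griffin state.
    void AllocateForAttentionTest(const ModelConfig& config, size_t batch_size,
                                  size_t seq_len) {
      std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
      AttentionActivations attention(config, config.layer_configs[0], batch_size,
                                     seq_len, MatPadding::kOdd, row_ptrs);
      attention.SetBatchSize(batch_size);
    }

    // The full pipeline still builds Activations; the Griffin buffers sit behind
    // a unique_ptr and are only allocated for Model::GRIFFIN_2B.
    void AllocateForGeneration(const ModelConfig& config, size_t batch_size,
                               size_t seq_len) {
      std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
      Activations activations(config, batch_size, seq_len, row_ptrs);
      // activations.griffin is null for non-Griffin models.
    }

    }  // namespace gcpp

Keeping the Griffin state behind a pointer means non-Griffin models pay only for a null pointer, which is where the memory saving for attention-only tests comes from.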

View File

@@ -360,6 +360,7 @@ cc_test(
     deps = [
         ":allocator",
         ":basics",
+        ":configs",
         ":gemma_lib",
         ":mat",
         ":ops",

View File

@@ -31,6 +31,30 @@
 namespace gcpp {
 
+struct GriffinActivations {
+  GriffinActivations(const ModelConfig& config, size_t batch_size,
+                     MatPadding pad)
+      : griffin_x("griffin_x", Extents2D(batch_size, config.model_dim), pad),
+        griffin_y("griffin_y", Extents2D(batch_size, config.model_dim), pad),
+        griffin_gate_x("griffin_gate_x",
+                       Extents2D(batch_size, config.model_dim), pad),
+        griffin_multiplier("griffin_mul",
+                           Extents2D(batch_size, config.model_dim), pad) {}
+
+  void SetBatchSize(size_t batch_size) {
+    griffin_x.OverrideRows(batch_size);
+    griffin_y.OverrideRows(batch_size);
+    griffin_gate_x.OverrideRows(batch_size);
+    griffin_multiplier.OverrideRows(batch_size);
+  }
+
+  MatStorageT<float> griffin_x;
+  MatStorageT<float> griffin_y;
+  MatStorageT<float> griffin_gate_x;
+  MatStorageT<float> griffin_multiplier;
+};
+
+struct AttentionActivations {
 // Returns the scale value to use for the query in the attention computation.
 // Also called by ops_test.
 static inline float ChooseQueryScale(const ModelConfig& config) {
@@ -41,15 +65,12 @@ static inline float ChooseQueryScale(const ModelConfig& config) {
   return 1.0f / sqrtf(static_cast<float>(config.layer_configs[0].qkv_dim));
 }
 
-struct Activations {
-  Activations(const ModelConfig& config, size_t batch_size, size_t seq_len,
-              std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
-      : weights_config(config),
-        layer_config(config.layer_configs[0]),
-        div_seq_len(static_cast<uint32_t>(seq_len)),
-        is_griffin(config.model == Model::GRIFFIN_2B),
-        x("x", Extents2D(batch_size, config.model_dim), pad_),
+  AttentionActivations(
+      const ModelConfig& config, const LayerConfig& layer_config,
+      size_t batch_size, size_t seq_len, MatPadding pad,
+      std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
+      : config(config),
         // `vocab_size == 0` means it is for Vit part, VitAttention is still MHA
         // and does not use an external KV cache.
         q("q",
@@ -57,36 +78,16 @@ struct Activations {
           config.vocab_size == 0
               ? layer_config.heads * 3 * layer_config.qkv_dim
               : layer_config.heads * layer_config.qkv_dim),
-          pad_),
-        logits("logits", Extents2D(batch_size, config.vocab_size), pad_),
+          pad),
         pre_att_rms_out("pre_att_rms_out",
-                        Extents2D(batch_size, config.model_dim), pad_),
-        att("att", Extents2D(batch_size, layer_config.heads * seq_len), pad_),
+                        Extents2D(batch_size, config.model_dim), pad),
+        att("att", Extents2D(batch_size, layer_config.heads * seq_len), pad),
         att_out(
             "att_out",
             Extents2D(batch_size, layer_config.heads * layer_config.qkv_dim),
-            pad_),
-        att_sums("att_sums", Extents2D(batch_size, config.model_dim), pad_),
-        pre_ffw_rms_out("pre_ffw_rms_out",
-                        Extents2D(batch_size, config.model_dim), pad_),
-        C1("C1", Extents2D(batch_size, layer_config.ff_hidden_dim), pad_),
-        C2("C2", Extents2D(batch_size, layer_config.ff_hidden_dim), pad_),
-        ffw_out("ffw_out", Extents2D(batch_size, config.model_dim), pad_),
-        griffin_x("griffin_x",
-                  is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
-                  pad_),
-        griffin_y("griffin_y",
-                  is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
-                  pad_),
-        griffin_gate_x(
-            "griffin_gate_x",
-            is_griffin ? Extents2D(batch_size, config.model_dim) : none_, pad_),
-        griffin_multiplier(
-            "griffin_mul",
-            is_griffin ? Extents2D(batch_size, config.model_dim) : none_, pad_),
+            pad),
+        att_sums("att_sums", Extents2D(batch_size, config.model_dim), pad),
         inv_timescale(
             CreateInvTimescale(layer_config.qkv_dim,
@@ -95,16 +96,73 @@ struct Activations {
             layer_config.qkv_dim, layer_config.post_qk == PostQKType::HalfRope,
             1000000.0)),
+        div_seq_len(static_cast<uint32_t>(seq_len)),
         query_scale(ChooseQueryScale(config)) {
     HWY_ASSERT(batch_size != 0);
 
     // For MatMul outputs, precompute their row pointers.
     // If we forget any MatMul outputs here, debug builds print a warning but
     // fill them in each MatMul call.
-    x.AllocateAndAttachRowPtrs(row_ptrs);
     q.AllocateAndAttachRowPtrs(row_ptrs);
-    logits.AllocateAndAttachRowPtrs(row_ptrs);
     att_sums.AllocateAndAttachRowPtrs(row_ptrs);
+  }
+
+  void SetBatchSize(size_t batch_size) {
+    q.OverrideRows(batch_size);
+    pre_att_rms_out.OverrideRows(batch_size);
+    att.OverrideRows(batch_size);
+    att_out.OverrideRows(batch_size);
+    att_sums.OverrideRows(batch_size);
+  }
+
+  bool IsGlobalLayer(size_t layer_idx) const {
+    return config.attention_window_sizes[layer_idx] == div_seq_len.GetDivisor();
+  }
+
+  const ModelConfig& config;
+
+  MatStorageT<float> q;  // query
+  MatStorageT<float> pre_att_rms_out;
+  MatStorageT<float> att;      // attention vector
+  MatStorageT<float> att_out;  // attention output
+  // Accumulation of attention outputs over heads
+  MatStorageT<BF16> att_sums;
+
+  // Rope
+  MatStorageT<float> inv_timescale;
+  MatStorageT<float> inv_timescale_global;
+
+  hwy::Divisor div_seq_len;
+  float query_scale;
+};
+
+struct Activations {
+  Activations(const ModelConfig& config, size_t batch_size, size_t seq_len,
+              std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
+      : layer_config(config.layer_configs[0]),
+        x("x", Extents2D(batch_size, config.model_dim), pad_),
+        logits("logits", Extents2D(batch_size, config.vocab_size), pad_),
+        pre_ffw_rms_out("pre_ffw_rms_out",
+                        Extents2D(batch_size, config.model_dim), pad_),
+        C1("C1", Extents2D(batch_size, layer_config.ff_hidden_dim), pad_),
+        C2("C2", Extents2D(batch_size, layer_config.ff_hidden_dim), pad_),
+        ffw_out("ffw_out", Extents2D(batch_size, config.model_dim), pad_),
+        attention(config, layer_config, batch_size, seq_len, pad_, row_ptrs) {
+    HWY_ASSERT(batch_size != 0);
+
+    if (config.model == Model::GRIFFIN_2B) {
+      griffin = std::make_unique<GriffinActivations>(config, batch_size, pad_);
+    }
+
+    // For MatMul outputs, precompute their row pointers.
+    // If we forget any MatMul outputs here, debug builds print a warning but
+    // fill them in each MatMul call.
+    x.AllocateAndAttachRowPtrs(row_ptrs);
+    logits.AllocateAndAttachRowPtrs(row_ptrs);
     C1.AllocateAndAttachRowPtrs(row_ptrs);
     C2.AllocateAndAttachRowPtrs(row_ptrs);
     ffw_out.AllocateAndAttachRowPtrs(row_ptrs);
@@ -115,67 +173,35 @@ struct Activations {
   void SetBatchSize(size_t batch_size) {
     PROFILER_ZONE("SetBatchSize");
     x.OverrideRows(batch_size);
-    q.OverrideRows(batch_size);
     logits.OverrideRows(batch_size);
-    pre_att_rms_out.OverrideRows(batch_size);
-    att.OverrideRows(batch_size);
-    att_out.OverrideRows(batch_size);
-    att_sums.OverrideRows(batch_size);
     pre_ffw_rms_out.OverrideRows(batch_size);
     C1.OverrideRows(batch_size);
     C2.OverrideRows(batch_size);
     ffw_out.OverrideRows(batch_size);
-    if (is_griffin) {
-      griffin_x.OverrideRows(batch_size);
-      griffin_y.OverrideRows(batch_size);
-      griffin_gate_x.OverrideRows(batch_size);
-      griffin_multiplier.OverrideRows(batch_size);
+
+    attention.SetBatchSize(batch_size);
+    if (griffin) {
+      griffin->SetBatchSize(batch_size);
     }
   }
 
-  bool IsGlobalLayer(size_t layer_idx) const {
-    return weights_config.attention_window_sizes[layer_idx] ==
-           div_seq_len.GetDivisor();
-  }
-
-  const ModelConfig& weights_config;
   const LayerConfig& layer_config;
-  hwy::Divisor div_seq_len;
-  bool is_griffin;
 
   const Extents2D none_ = Extents2D();
   const MatPadding pad_ = MatPadding::kOdd;
 
   MatStorageT<float> x;  // input
-  MatStorageT<float> q;  // query
   MatStorageT<float> logits;
 
-  // Attention
-  MatStorageT<float> pre_att_rms_out;
-  MatStorageT<float> att;      // attention vector
-  MatStorageT<float> att_out;  // attention output
-  // Accumulation of attention outputs over heads
-  MatStorageT<BF16> att_sums;
-
   // Gated FFW
   MatStorageT<BF16> pre_ffw_rms_out;
   MatStorageT<float> C1;  // TODO: BF16 after Activation() supports it
   MatStorageT<float> C2;
   MatStorageT<BF16> ffw_out;
 
-  // Griffin
-  MatStorageT<float> griffin_x;
-  MatStorageT<float> griffin_y;
-  MatStorageT<float> griffin_gate_x;
-  MatStorageT<float> griffin_multiplier;
-
-  // Rope
-  MatStorageT<float> inv_timescale;
-  MatStorageT<float> inv_timescale_global;
-
-  float query_scale;
+  AttentionActivations attention;
+  std::unique_ptr<GriffinActivations> griffin;
 };
 
 }  // namespace gcpp

View File

@@ -20,7 +20,6 @@
 #include "gemma/activations.h"
 #include "gemma/gemma.h"
-#include "gemma/gemma_args.h"
 #include "gemma/weights.h"
 #include "util/threading.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
@@ -66,14 +65,14 @@ template <typename U>
 static void PositionalEncodingQK(U* qk, const size_t qkv_dim,
                                  const size_t layer_idx,
                                  const LayerWeightsPtrs& layer,
-                                 const Activations& activations,
+                                 const AttentionActivations& activations,
                                  const size_t pos, const float mul = 1.0f) {
   const PostQKType& post_qk = layer.layer_config.post_qk;
   // qk is either q or k, so qkv_dim is the length we operate on.
   const float* inv_timescale = activations.inv_timescale.PackedScale1();
   bool is_global_layer = activations.IsGlobalLayer(layer_idx);
   // TODO: add a config flag instead of hardcoding the model.
-  if (is_global_layer && IsVLM(activations.weights_config.model)) {
+  if (is_global_layer && IsVLM(activations.config.model)) {
     inv_timescale = activations.inv_timescale_global.PackedScale1();
   }
   // PostQKType::Rope
@@ -118,10 +117,10 @@ void SingleDotSoftmaxWeightedSum(
     const size_t pos, const size_t start_pos, const size_t last_pos,
     float* HWY_RESTRICT q, const MatPtrT<float>& k, const MatPtrT<float>& v,
     const size_t layer_idx, const LayerWeightsPtrs& layer,
-    const Activations& activations, float* HWY_RESTRICT att,
+    const AttentionActivations& activations, float* HWY_RESTRICT att,
     float* HWY_RESTRICT att_out) {
   const size_t qkv_dim = layer.layer_config.qkv_dim;
-  const float att_cap = activations.weights_config.att_cap;
+  const float att_cap = activations.config.att_cap;
   const float query_scale = activations.query_scale;
   const size_t seq_len =
       static_cast<size_t>(activations.div_seq_len.GetDivisor());
@@ -155,7 +154,7 @@ static HWY_INLINE size_t StartPos(size_t pos, const ModelConfig& config,
 void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
                            const LayerWeightsPtrs& layer,
-                           Activations& activations, QBatch& qbatch,
+                           AttentionActivations& activations, QBatch& qbatch,
                            NestedPools& pools) {
   PROFILER_ZONE("Gen.Attention.DotSoftmax");
 
   const hwy::Divisor div_qbatch(qbatch.Size());
@@ -190,7 +189,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
         // the range of cache positions to attend to.
         const size_t pos = qbatch.Pos(qi) + batch_idx;
         const size_t start_pos =
-            StartPos(pos, activations.weights_config, layer_idx);
+            StartPos(pos, activations.config, layer_idx);
         size_t last_pos = pos;
         const size_t prefix_end = qbatch.PrefixEnd(qi);
         if (prefix_end > 0 && prefix_end - 1 > last_pos) {
@@ -241,7 +240,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
 // Fills activations.q and writes to KV cache.
 static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
                                   const LayerWeightsPtrs& layer,
-                                  Activations& activations,
+                                  AttentionActivations& activations,
                                   const QBatch& qbatch, const int flags,
                                   MatMulEnv& env) {
   PROFILER_ZONE("Gen.Attention.QKV");
@@ -306,7 +305,8 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
 // Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and
 // head_dim (`qkv_dim`) into output (`layer_out`).
 static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
-                                Activations& activations, MatMulEnv& env) {
+                                AttentionActivations& activations,
+                                MatMulEnv& env) {
   PROFILER_ZONE("Gen.Attention.SumHeads");
   const LayerConfig& layer_config = layer.layer_config;
   // att_weights and att_out are concatenated heads, each of length
@@ -324,8 +324,9 @@ static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
 }
 
 void GemmaAttention(size_t num_tokens, const size_t layer_idx,
-                    const LayerWeightsPtrs& layer, Activations& activations,
-                    QBatch& qbatch, MatMulEnv& env, int flags) {
+                    const LayerWeightsPtrs& layer,
+                    AttentionActivations& activations, QBatch& qbatch,
+                    MatMulEnv& env, int flags) {
   const LayerConfig& layer_config = layer.layer_config;
   HWY_DASSERT(!layer_config.IsMHA());  // No longer supported.
   HWY_DASSERT_M((layer_config.heads % layer_config.kv_heads) == 0,

View File

@@ -32,17 +32,18 @@ namespace gcpp {
       const size_t pos, const size_t start_pos, const size_t last_pos,        \
       float* HWY_RESTRICT q, const MatPtrT<float>& k, const MatPtrT<float>& v, \
       size_t layer_idx, const LayerWeightsPtrs& layer,                        \
-      const Activations& activations, float* HWY_RESTRICT att,                \
+      const AttentionActivations& activations, float* HWY_RESTRICT att,       \
       float* HWY_RESTRICT att_out);                                           \
                                                                               \
   void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx,       \
                              const LayerWeightsPtrs& layer,                   \
-                             Activations& activations, QBatch& qbatch,        \
-                             NestedPools& pools);                             \
+                             AttentionActivations& activations,               \
+                             QBatch& qbatch, NestedPools& pools);             \
                                                                               \
   void GemmaAttention(size_t num_tokens, const size_t layer_idx,              \
-                      const LayerWeightsPtrs& layer, Activations& activations, \
-                      QBatch& qbatch, MatMulEnv& env, int flags);             \
+                      const LayerWeightsPtrs& layer,                          \
+                      AttentionActivations& activations, QBatch& qbatch,      \
+                      MatMulEnv& env, int flags);                             \
   /* NOLINTNEXTLINE(google-readability-namespace-comments) */                 \
   }  // namespace NAMESPACE

View File

@@ -65,14 +65,15 @@ void Attention(LayerAttentionType type, const size_t num_tokens,
                const size_t layer_idx, const LayerWeightsPtrs& layer,
                Activations& activations, QBatch& qbatch, MatMulEnv& env) {
   if (type == LayerAttentionType::kGemma) {
-    GemmaAttention(num_tokens, layer_idx, layer, activations, qbatch, env,
+    GemmaAttention(num_tokens, layer_idx, layer, activations.attention, qbatch,
+                   env,
                    /*flags=*/0);
   } else {
     HWY_DASSERT(type == LayerAttentionType::kGriffinRecurrentBlock);
     // KVCache conv1d_cache and rglru_cache have one row per *Griffin* layer,
     // so map `layer` to the Griffin layer index.
     const size_t griffin_layer =
-        activations.weights_config.NumLayersOfTypeBefore(type, layer_idx);
+        activations.attention.config.NumLayersOfTypeBefore(type, layer_idx);
     GriffinRecurrent(num_tokens, griffin_layer, &layer, activations, qbatch,
                      env);
   }
@@ -86,15 +87,15 @@ static HWY_NOINLINE void TransformerLayer(const size_t num_tokens,
   const LayerConfig& layer_config = layer.layer_config;
 
   RMSNormBatched(activations.x, layer.pre_attention_norm_scale,
-                 activations.pre_att_rms_out);
+                 activations.attention.pre_att_rms_out);
 
   Attention(layer_config.type, num_tokens, layer_idx, layer, activations,
             qbatch, env);
 
   PostNorm(layer_config.post_norm, layer.post_attention_norm_scale,
-           activations.att_sums);
+           activations.attention.att_sums);
 
-  ResidualConnection(activations.att_sums, activations.x, layer,
+  ResidualConnection(activations.attention.att_sums, activations.x, layer,
                      /*is_attention=*/true);
 
   RMSNormBatched(activations.x, layer.pre_ffw_norm_scale,
@@ -470,7 +471,7 @@ static void GenerateT(const ModelConfig& config,
     HWY_ASSERT(qbatch.KV(qi).SeqLen() == seq_len);
   }
   HWY_ASSERT(prefill_tokens < seq_len);
-  activations.div_seq_len = hwy::Divisor(static_cast<uint32_t>(seq_len));
+  HWY_ASSERT(activations.attention.div_seq_len.GetDivisor() == seq_len);
 
   // Lacks a constructor to bulk-set, hence initialized by Prefill* which have
   // qi loops anyway.

View File

@@ -60,20 +60,21 @@ void GriffinRecurrent(size_t num_tokens, size_t griffin_layer,
   const size_t num_interleaved = num_tokens * qbatch.Size();
   const hwy::Divisor div_qbatch(static_cast<uint32_t>(qbatch.Size()));
+  GriffinActivations& griffin = *activations.griffin;
 
   // X / Y linear layers.
   // TODO: MatMul
-  HWY_DASSERT(activations.griffin_y.Rows() == activations.griffin_x.Rows());
-  HWY_DASSERT(num_interleaved == activations.griffin_y.Rows());
+  HWY_DASSERT(griffin.griffin_y.Rows() == griffin.griffin_x.Rows());
+  HWY_DASSERT(num_interleaved == griffin.griffin_y.Rows());
   CallUpcastedSame(
       &layer_weights->griffin.linear_x_w, &layer_weights->griffin.linear_y_w,
       [&](const auto* wx, const auto* wy) {
         for (size_t r = 0; r < num_interleaved; ++r) {
-          float* HWY_RESTRICT y = activations.griffin_y.Row(r);
-          float* HWY_RESTRICT x = activations.griffin_x.Row(r);
+          float* HWY_RESTRICT y = griffin.griffin_y.Row(r);
+          float* HWY_RESTRICT x = griffin.griffin_x.Row(r);
           TwoMatVecAdd(
               *wx, *wy, 0, model_dim, model_dim,
-              activations.pre_att_rms_out.Row(r),
+              activations.attention.pre_att_rms_out.Row(r),
               /*add0=*/layer_weights->griffin.linear_x_biases.PackedScale1(),
              /*add1=*/layer_weights->griffin.linear_y_biases.PackedScale1(),
              /*out0=*/x, /*out1=*/y, pool);
@@ -87,7 +88,7 @@ void GriffinRecurrent(size_t num_tokens, size_t griffin_layer,
     const size_t qi = div_qbatch.Remainder(interleaved_idx);
     const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
     const size_t pos = qbatch.Pos(qi) + batch_idx;
-    float* HWY_RESTRICT x = activations.griffin_x.Row(qi);
+    float* HWY_RESTRICT x = griffin.griffin_x.Row(qi);
 
     // cache[i] = input at time t-i.
     float* HWY_RESTRICT cache[kMaxConv1DWidth];
@@ -124,10 +125,10 @@ void GriffinRecurrent(size_t num_tokens, size_t griffin_layer,
     const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
     const size_t pos = qbatch.Pos(qi) + batch_idx;
 
-    float* HWY_RESTRICT x = activations.griffin_x.Row(qi);
-    float* HWY_RESTRICT y = activations.griffin_y.Row(qi);
-    float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Row(qi);
-    float* HWY_RESTRICT a = activations.griffin_multiplier.Row(qi);
+    float* HWY_RESTRICT x = griffin.griffin_x.Row(qi);
+    float* HWY_RESTRICT y = griffin.griffin_y.Row(qi);
+    float* HWY_RESTRICT gate_x = griffin.griffin_gate_x.Row(qi);
+    float* HWY_RESTRICT a = griffin.griffin_multiplier.Row(qi);
     float* HWY_RESTRICT rnn_state =
         qbatch.KV(qi).rglru_cache.Row(griffin_layer);
@@ -175,9 +176,9 @@ void GriffinRecurrent(size_t num_tokens, size_t griffin_layer,
   }  // interleaved_idx
 
   // Final linear layer.
-  CallMatMul(activations.griffin_x, layer_weights->griffin.linear_out_w,
+  CallMatMul(griffin.griffin_x, layer_weights->griffin.linear_out_w,
              layer_weights->griffin.linear_out_biases.PackedScale1(), env,
-             activations.att_sums);
+             activations.attention.att_sums);
 }  // GriffinRecurrent
 
 // NOLINTNEXTLINE(google-readability-namespace-comments)

View File

@@ -56,10 +56,10 @@ class VitAttention {
   // Computes Q, K, V for all heads, stored in activations_.q.
   HWY_NOINLINE void ComputeQKV() {
     PROFILER_ZONE("Gen.VitAttention.QKV");
-    auto& qkv = activations_.q;
+    auto& qkv = activations_.attention.q;
     HWY_ASSERT(qkv.Rows() == num_tokens_);
     HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim);
-    CallMatMul(activations_.pre_att_rms_out, layer_.vit.qkv_einsum_w,
+    CallMatMul(activations_.attention.pre_att_rms_out, layer_.vit.qkv_einsum_w,
                layer_.vit.qkv_einsum_b.PackedScale1(), env_, qkv);
   }
 
@@ -69,7 +69,7 @@ class VitAttention {
     const size_t heads = layer_config_.heads;
     HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
     const size_t seq_len =
-        static_cast<size_t>(activations_.div_seq_len.GetDivisor());
+        static_cast<size_t>(activations_.attention.div_seq_len.GetDivisor());
     const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
     PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
 
@@ -82,12 +82,13 @@ class VitAttention {
                          MatPadding::kPacked);
 
     // Initialize att_out to zero prior to head loop.
-    ZeroInit(activations_.att_out);
+    ZeroInit(activations_.attention.att_out);
 
     for (size_t head = 0; head < heads; ++head) {
       pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
         const size_t token = task;
-        float* HWY_RESTRICT q = activations_.q.Row(token) + head * 3 * qkv_dim;
+        float* HWY_RESTRICT q =
+            activations_.attention.q.Row(token) + head * 3 * qkv_dim;
         // TODO: shift to MatMul with A.scale once MatMul is confirmed working
         MulByConst(query_scale, q, qkv_dim);
         hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float));
@@ -95,8 +96,8 @@ class VitAttention {
 
       pool_.Run(0, seq_len, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
         const size_t seq_idx = task;
-        float* HWY_RESTRICT k =
-            activations_.q.Row(seq_idx) + head * 3 * qkv_dim + qkv_dim;
+        float* HWY_RESTRICT k = activations_.attention.q.Row(seq_idx) +
+                                head * 3 * qkv_dim + qkv_dim;
         hwy::CopyBytes(k, K.Row(seq_idx), qkv_dim * sizeof(float));
       });
 
@@ -111,10 +112,10 @@ class VitAttention {
       pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
        size_t token = task;
        float* HWY_RESTRICT att_out =
-            activations_.att_out.Row(token) + head * qkv_dim;
+            activations_.attention.att_out.Row(token) + head * qkv_dim;
        for (size_t i = 0; i < seq_len; ++i) {
-          float* HWY_RESTRICT v =
-              activations_.q.Row(i) + head * 3 * qkv_dim + 2 * qkv_dim;
+          float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
+                                  head * 3 * qkv_dim + 2 * qkv_dim;
          MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim);
        }
       });
@@ -126,7 +127,7 @@ class VitAttention {
     const size_t heads = layer_config_.heads;
     HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
     const size_t seq_len =
-        static_cast<size_t>(activations_.div_seq_len.GetDivisor());
+        static_cast<size_t>(activations_.attention.div_seq_len.GetDivisor());
     const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
     PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
 
@@ -137,24 +138,24 @@ class VitAttention {
       const size_t token = task / layer_config_.heads;
       // Compute Q.K scores, which are "logits" stored in head_att.
       float* HWY_RESTRICT q =
-          activations_.q.Row(token) + head * 3 * qkv_dim;
+          activations_.attention.q.Row(token) + head * 3 * qkv_dim;
       MulByConst(query_scale, q, qkv_dim);
       float* HWY_RESTRICT head_att =
-          activations_.att.Row(token) + head * seq_len;
+          activations_.attention.att.Row(token) + head * seq_len;
      for (size_t i = 0; i < seq_len; ++i) {
-        float* HWY_RESTRICT k =
-            activations_.q.Row(i) + head * 3 * qkv_dim + qkv_dim;
+        float* HWY_RESTRICT k = activations_.attention.q.Row(i) +
+                                head * 3 * qkv_dim + qkv_dim;
        head_att[i] = Dot(q, k, qkv_dim);  // score = q.k
      }
      // SoftMax yields "probabilities" in head_att.
      Softmax(head_att, seq_len);
      // Compute weighted sum of v into att_out.
      float* HWY_RESTRICT att_out =
-          activations_.att_out.Row(token) + head * qkv_dim;
+          activations_.attention.att_out.Row(token) + head * qkv_dim;
      hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
      for (size_t i = 0; i < seq_len; ++i) {
-        float* HWY_RESTRICT v =
-            activations_.q.Row(i) + head * 3 * qkv_dim + 2 * qkv_dim;
+        float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
+                                head * 3 * qkv_dim + 2 * qkv_dim;
        MulByConstAndAdd(head_att[i], v, att_out, qkv_dim);
      }
     });
@@ -168,8 +169,8 @@ class VitAttention {
     // att_weights and att_out are concatenated heads, each of length
     // qkv_dim. Thus the [num_tokens_, layer_config_.model_dim]
     // matmul output is the sum over heads.
-    CallMatMul(activations_.att_out, layer_.vit.attn_out_w, bias, env_,
-               activations_.att_sums);
+    CallMatMul(activations_.attention.att_out, layer_.vit.attn_out_w, bias,
+               env_, activations_.attention.att_sums);
   }
 
  public:
@@ -184,7 +185,7 @@ class VitAttention {
   HWY_INLINE void operator()() {
     ComputeQKV();
-    if (activations_.weights_config.wrapping == PromptWrapping::GEMMA_VLM) {
+    if (activations_.attention.config.wrapping == PromptWrapping::GEMMA_VLM) {
       DotSoftmaxWeightedSumMatrix();
     } else {
       DotSoftmaxWeightedSum();
@@ -233,7 +234,7 @@ void FFWVit(const LayerWeightsPtrs& layer, Activations& activations,
 void VitTransformerLayer(size_t num_tokens, const size_t layer_idx,
                          const LayerWeightsPtrs& layer,
                          Activations& activations, MatMulEnv& env) {
-  const size_t model_dim = activations.weights_config.model_dim;
+  const size_t model_dim = activations.attention.config.model_dim;
   auto type = layer.layer_config.type;
   HWY_DASSERT(type == LayerAttentionType::kVit);
   (void)type;
@@ -246,14 +247,14 @@ void VitTransformerLayer(size_t num_tokens, const size_t layer_idx,
   // y = nn.LayerNorm()(x)
   // y ~ pre_att_rms_out
   LayerNormBatched(x, layer.vit.layer_norm_0_scale, layer.vit.layer_norm_0_bias,
-                   activations.pre_att_rms_out);
+                   activations.attention.pre_att_rms_out);
 
   // y = out["sa"] = nn.MultiHeadDotProductAttention(...)(y, y)
   // y ~ att_sums
   VitAttention(num_tokens, layer_idx, activations, layer, env)();
 
   // x = out["+sa"] = x + y
-  AddFromBatched(activations.att_sums, x);
+  AddFromBatched(activations.attention.att_sums, x);
 
   // y = nn.LayerNorm()(x)
   // y ~ pre_ffw_rms_out

View File

@@ -32,6 +32,7 @@
 #include <vector>
 
 #include "gemma/activations.h"  // ChooseQueryScale
+#include "gemma/configs.h"
 #include "util/allocator.h"
 #include "util/basics.h"  // BF16
 #include "util/mat.h"     // MatStorageT
@@ -400,7 +401,7 @@ void TestRopeAndMulBy() {
     x.Row(0)[i] = random_float();
   }
 
-  const float qmul = ChooseQueryScale(config);
+  const float qmul = AttentionActivations::ChooseQueryScale(config);
   const float kmul = 1.0;
 
   MatStorageT<float> qexpected("qexpected", dim_qkv);