mirror of https://github.com/google/gemma.cpp.git
Split Activations into Griffin/Attention to reduce memory usage for attention-only tests.
PiperOrigin-RevId: 772025282
parent 2128d076db
commit 6773e4517c
@@ -360,6 +360,7 @@ cc_test(
    deps = [
        ":allocator",
        ":basics",
        ":configs",
        ":gemma_lib",
        ":mat",
        ":ops",
@@ -31,25 +31,46 @@
namespace gcpp {

// Returns the scale value to use for the query in the attention computation.
// Also called by ops_test.
static inline float ChooseQueryScale(const ModelConfig& config) {
  if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads)
    return 1.0f / sqrtf(static_cast<float>(config.model_dim /
                                           config.layer_configs[0].heads));
  // QueryScaleType::SqrtKeySize
  return 1.0f / sqrtf(static_cast<float>(config.layer_configs[0].qkv_dim));
}
struct GriffinActivations {
  GriffinActivations(const ModelConfig& config, size_t batch_size,
                     MatPadding pad)
      : griffin_x("griffin_x", Extents2D(batch_size, config.model_dim), pad),
        griffin_y("griffin_y", Extents2D(batch_size, config.model_dim), pad),
        griffin_gate_x("griffin_gate_x",
                       Extents2D(batch_size, config.model_dim), pad),
        griffin_multiplier("griffin_mul",
                           Extents2D(batch_size, config.model_dim), pad) {}

struct Activations {
  Activations(const ModelConfig& config, size_t batch_size, size_t seq_len,
              std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
      : weights_config(config),
        layer_config(config.layer_configs[0]),
        div_seq_len(static_cast<uint32_t>(seq_len)),
        is_griffin(config.model == Model::GRIFFIN_2B),
  void SetBatchSize(size_t batch_size) {
    griffin_x.OverrideRows(batch_size);
    griffin_y.OverrideRows(batch_size);
    griffin_gate_x.OverrideRows(batch_size);
    griffin_multiplier.OverrideRows(batch_size);
  }

  MatStorageT<float> griffin_x;
  MatStorageT<float> griffin_y;
  MatStorageT<float> griffin_gate_x;
  MatStorageT<float> griffin_multiplier;
};

struct AttentionActivations {
  // Returns the scale value to use for the query in the attention computation.
  // Also called by ops_test.
  static inline float ChooseQueryScale(const ModelConfig& config) {
    if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads)
      return 1.0f / sqrtf(static_cast<float>(config.model_dim /
                                             config.layer_configs[0].heads));
    // QueryScaleType::SqrtKeySize
    return 1.0f / sqrtf(static_cast<float>(config.layer_configs[0].qkv_dim));
  }

  AttentionActivations(
      const ModelConfig& config, const LayerConfig& layer_config,
      size_t batch_size, size_t seq_len, MatPadding pad,
      std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
      : config(config),

        x("x", Extents2D(batch_size, config.model_dim), pad_),
        // `vocab_size == 0` means it is for Vit part, VitAttention is still MHA
        // and does not use an external KV cache.
        q("q",
@@ -57,36 +78,16 @@ struct Activations {
              config.vocab_size == 0
                  ? layer_config.heads * 3 * layer_config.qkv_dim
                  : layer_config.heads * layer_config.qkv_dim),
          pad_),
        logits("logits", Extents2D(batch_size, config.vocab_size), pad_),
          pad),

        pre_att_rms_out("pre_att_rms_out",
                        Extents2D(batch_size, config.model_dim), pad_),
        att("att", Extents2D(batch_size, layer_config.heads * seq_len), pad_),
                        Extents2D(batch_size, config.model_dim), pad),
        att("att", Extents2D(batch_size, layer_config.heads * seq_len), pad),
        att_out(
            "att_out",
            Extents2D(batch_size, layer_config.heads * layer_config.qkv_dim),
            pad_),
        att_sums("att_sums", Extents2D(batch_size, config.model_dim), pad_),

        pre_ffw_rms_out("pre_ffw_rms_out",
                        Extents2D(batch_size, config.model_dim), pad_),
        C1("C1", Extents2D(batch_size, layer_config.ff_hidden_dim), pad_),
        C2("C2", Extents2D(batch_size, layer_config.ff_hidden_dim), pad_),
        ffw_out("ffw_out", Extents2D(batch_size, config.model_dim), pad_),

        griffin_x("griffin_x",
                  is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
                  pad_),
        griffin_y("griffin_y",
                  is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
                  pad_),
        griffin_gate_x(
            "griffin_gate_x",
            is_griffin ? Extents2D(batch_size, config.model_dim) : none_, pad_),
        griffin_multiplier(
            "griffin_mul",
            is_griffin ? Extents2D(batch_size, config.model_dim) : none_, pad_),
            pad),
        att_sums("att_sums", Extents2D(batch_size, config.model_dim), pad),

        inv_timescale(
            CreateInvTimescale(layer_config.qkv_dim,
@@ -95,16 +96,73 @@ struct Activations {
            layer_config.qkv_dim, layer_config.post_qk == PostQKType::HalfRope,
            1000000.0)),

        div_seq_len(static_cast<uint32_t>(seq_len)),
        query_scale(ChooseQueryScale(config)) {
    HWY_ASSERT(batch_size != 0);

    // For MatMul outputs, precompute their row pointers.
    // If we forget any MatMul outputs here, debug builds print a warning but
    // fill them in each MatMul call.
    x.AllocateAndAttachRowPtrs(row_ptrs);
    q.AllocateAndAttachRowPtrs(row_ptrs);
    logits.AllocateAndAttachRowPtrs(row_ptrs);
    att_sums.AllocateAndAttachRowPtrs(row_ptrs);
  }

  void SetBatchSize(size_t batch_size) {
    q.OverrideRows(batch_size);

    pre_att_rms_out.OverrideRows(batch_size);
    att.OverrideRows(batch_size);
    att_out.OverrideRows(batch_size);
    att_sums.OverrideRows(batch_size);
  }

  bool IsGlobalLayer(size_t layer_idx) const {
    return config.attention_window_sizes[layer_idx] == div_seq_len.GetDivisor();
  }

  const ModelConfig& config;

  MatStorageT<float> q;  // query

  MatStorageT<float> pre_att_rms_out;
  MatStorageT<float> att;      // attention vector
  MatStorageT<float> att_out;  // attention output
  // Accumulation of attention outputs over heads
  MatStorageT<BF16> att_sums;

  // Rope
  MatStorageT<float> inv_timescale;
  MatStorageT<float> inv_timescale_global;

  hwy::Divisor div_seq_len;
  float query_scale;
};

struct Activations {
  Activations(const ModelConfig& config, size_t batch_size, size_t seq_len,
              std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
      : layer_config(config.layer_configs[0]),

        x("x", Extents2D(batch_size, config.model_dim), pad_),
        logits("logits", Extents2D(batch_size, config.vocab_size), pad_),

        pre_ffw_rms_out("pre_ffw_rms_out",
                        Extents2D(batch_size, config.model_dim), pad_),
        C1("C1", Extents2D(batch_size, layer_config.ff_hidden_dim), pad_),
        C2("C2", Extents2D(batch_size, layer_config.ff_hidden_dim), pad_),
        ffw_out("ffw_out", Extents2D(batch_size, config.model_dim), pad_),

        attention(config, layer_config, batch_size, seq_len, pad_, row_ptrs) {
    HWY_ASSERT(batch_size != 0);
    if (config.model == Model::GRIFFIN_2B) {
      griffin = std::make_unique<GriffinActivations>(config, batch_size, pad_);
    }

    // For MatMul outputs, precompute their row pointers.
    // If we forget any MatMul outputs here, debug builds print a warning but
    // fill them in each MatMul call.
    x.AllocateAndAttachRowPtrs(row_ptrs);
    logits.AllocateAndAttachRowPtrs(row_ptrs);
    C1.AllocateAndAttachRowPtrs(row_ptrs);
    C2.AllocateAndAttachRowPtrs(row_ptrs);
    ffw_out.AllocateAndAttachRowPtrs(row_ptrs);
@@ -115,67 +173,35 @@ struct Activations {
  void SetBatchSize(size_t batch_size) {
    PROFILER_ZONE("SetBatchSize");
    x.OverrideRows(batch_size);
    q.OverrideRows(batch_size);
    logits.OverrideRows(batch_size);

    pre_att_rms_out.OverrideRows(batch_size);
    att.OverrideRows(batch_size);
    att_out.OverrideRows(batch_size);
    att_sums.OverrideRows(batch_size);

    pre_ffw_rms_out.OverrideRows(batch_size);
    C1.OverrideRows(batch_size);
    C2.OverrideRows(batch_size);
    ffw_out.OverrideRows(batch_size);

    if (is_griffin) {
      griffin_x.OverrideRows(batch_size);
      griffin_y.OverrideRows(batch_size);
      griffin_gate_x.OverrideRows(batch_size);
      griffin_multiplier.OverrideRows(batch_size);
    attention.SetBatchSize(batch_size);

    if (griffin) {
      griffin->SetBatchSize(batch_size);
    }
  }

  bool IsGlobalLayer(size_t layer_idx) const {
    return weights_config.attention_window_sizes[layer_idx] ==
           div_seq_len.GetDivisor();
  }

  const ModelConfig& weights_config;
  const LayerConfig& layer_config;
  hwy::Divisor div_seq_len;
  bool is_griffin;
  const Extents2D none_ = Extents2D();
  const MatPadding pad_ = MatPadding::kOdd;

  MatStorageT<float> x;  // input
  MatStorageT<float> q;  // query
  MatStorageT<float> logits;

  // Attention
  MatStorageT<float> pre_att_rms_out;
  MatStorageT<float> att;      // attention vector
  MatStorageT<float> att_out;  // attention output
  // Accumulation of attention outputs over heads
  MatStorageT<BF16> att_sums;

  // Gated FFW
  MatStorageT<BF16> pre_ffw_rms_out;
  MatStorageT<float> C1;  // TODO: BF16 after Activation() supports it
  MatStorageT<float> C2;
  MatStorageT<BF16> ffw_out;

  // Griffin
  MatStorageT<float> griffin_x;
  MatStorageT<float> griffin_y;
  MatStorageT<float> griffin_gate_x;
  MatStorageT<float> griffin_multiplier;

  // Rope
  MatStorageT<float> inv_timescale;
  MatStorageT<float> inv_timescale_global;

  float query_scale;
  AttentionActivations attention;
  std::unique_ptr<GriffinActivations> griffin;
};

}  // namespace gcpp
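The point of the split, per the commit message, is that attention-only tests no longer pay for FFW or Griffin buffers: AttentionActivations can be constructed on its own, and GriffinActivations is only heap-allocated when the model is GRIFFIN_2B. The following is a rough sketch, not part of this commit, of how such a test setup might look; the constructor signature, SetBatchSize, and MatPadding::kOdd are taken from the hunks above, while the helper name, batch/sequence sizes, and exact include list are illustrative assumptions.

#include <cstddef>
#include <cstdint>
#include <vector>

#include "gemma/activations.h"  // AttentionActivations (per this commit)
#include "gemma/configs.h"      // ModelConfig, LayerConfig
#include "hwy/aligned_allocator.h"

namespace gcpp {

// Hypothetical helper: allocate only the attention buffers for a test,
// skipping the FFW and Griffin storage owned by the full Activations.
void AllocateAttentionOnly(const ModelConfig& config) {
  const LayerConfig& layer_config = config.layer_configs[0];
  const size_t batch_size = 4;  // illustrative
  const size_t seq_len = 512;   // illustrative
  // Row-pointer storage for MatMul outputs, as in the constructors above.
  std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
  AttentionActivations attention(config, layer_config, batch_size, seq_len,
                                 MatPadding::kOdd, row_ptrs);
  // Later, shrink to a smaller batch without reallocating.
  attention.SetBatchSize(1);
}

}  // namespace gcpp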
@@ -20,7 +20,6 @@

#include "gemma/activations.h"
#include "gemma/gemma.h"
#include "gemma/gemma_args.h"
#include "gemma/weights.h"
#include "util/threading.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
@@ -66,14 +65,14 @@ template <typename U>
static void PositionalEncodingQK(U* qk, const size_t qkv_dim,
                                 const size_t layer_idx,
                                 const LayerWeightsPtrs& layer,
                                 const Activations& activations,
                                 const AttentionActivations& activations,
                                 const size_t pos, const float mul = 1.0f) {
  const PostQKType& post_qk = layer.layer_config.post_qk;
  // qk is either q or k, so qkv_dim is the length we operate on.
  const float* inv_timescale = activations.inv_timescale.PackedScale1();
  bool is_global_layer = activations.IsGlobalLayer(layer_idx);
  // TODO: add a config flag instead of hardcoding the model.
  if (is_global_layer && IsVLM(activations.weights_config.model)) {
  if (is_global_layer && IsVLM(activations.config.model)) {
    inv_timescale = activations.inv_timescale_global.PackedScale1();
  }
  // PostQKType::Rope
@@ -118,10 +117,10 @@ void SingleDotSoftmaxWeightedSum(
    const size_t pos, const size_t start_pos, const size_t last_pos,
    float* HWY_RESTRICT q, const MatPtrT<float>& k, const MatPtrT<float>& v,
    const size_t layer_idx, const LayerWeightsPtrs& layer,
    const Activations& activations, float* HWY_RESTRICT att,
    const AttentionActivations& activations, float* HWY_RESTRICT att,
    float* HWY_RESTRICT att_out) {
  const size_t qkv_dim = layer.layer_config.qkv_dim;
  const float att_cap = activations.weights_config.att_cap;
  const float att_cap = activations.config.att_cap;
  const float query_scale = activations.query_scale;
  const size_t seq_len =
      static_cast<size_t>(activations.div_seq_len.GetDivisor());
@@ -155,7 +154,7 @@ static HWY_INLINE size_t StartPos(size_t pos, const ModelConfig& config,

void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
                           const LayerWeightsPtrs& layer,
                           Activations& activations, QBatch& qbatch,
                           AttentionActivations& activations, QBatch& qbatch,
                           NestedPools& pools) {
  PROFILER_ZONE("Gen.Attention.DotSoftmax");
  const hwy::Divisor div_qbatch(qbatch.Size());
@@ -190,7 +189,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
      // the range of cache positions to attend to.
      const size_t pos = qbatch.Pos(qi) + batch_idx;
      const size_t start_pos =
          StartPos(pos, activations.weights_config, layer_idx);
          StartPos(pos, activations.config, layer_idx);
      size_t last_pos = pos;
      const size_t prefix_end = qbatch.PrefixEnd(qi);
      if (prefix_end > 0 && prefix_end - 1 > last_pos) {
@@ -241,7 +240,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
// Fills activations.q and writes to KV cache.
static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
                                  const LayerWeightsPtrs& layer,
                                  Activations& activations,
                                  AttentionActivations& activations,
                                  const QBatch& qbatch, const int flags,
                                  MatMulEnv& env) {
  PROFILER_ZONE("Gen.Attention.QKV");
@@ -306,7 +305,8 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
// Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and
// head_dim (`qkv_dim`) into output (`layer_out`).
static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
                                Activations& activations, MatMulEnv& env) {
                                AttentionActivations& activations,
                                MatMulEnv& env) {
  PROFILER_ZONE("Gen.Attention.SumHeads");
  const LayerConfig& layer_config = layer.layer_config;
  // att_weights and att_out are concatenated heads, each of length
@@ -324,8 +324,9 @@ static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
}

void GemmaAttention(size_t num_tokens, const size_t layer_idx,
                    const LayerWeightsPtrs& layer, Activations& activations,
                    QBatch& qbatch, MatMulEnv& env, int flags) {
                    const LayerWeightsPtrs& layer,
                    AttentionActivations& activations, QBatch& qbatch,
                    MatMulEnv& env, int flags) {
  const LayerConfig& layer_config = layer.layer_config;
  HWY_DASSERT(!layer_config.IsMHA());  // No longer supported.
  HWY_DASSERT_M((layer_config.heads % layer_config.kv_heads) == 0,
@@ -32,17 +32,18 @@ namespace gcpp {
      const size_t pos, const size_t start_pos, const size_t last_pos, \
      float* HWY_RESTRICT q, const MatPtrT<float>& k, const MatPtrT<float>& v, \
      size_t layer_idx, const LayerWeightsPtrs& layer, \
      const Activations& activations, float* HWY_RESTRICT att, \
      const AttentionActivations& activations, float* HWY_RESTRICT att, \
      float* HWY_RESTRICT att_out); \
  \
  void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \
                             const LayerWeightsPtrs& layer, \
                             Activations& activations, QBatch& qbatch, \
                             NestedPools& pools); \
                             AttentionActivations& activations, \
                             QBatch& qbatch, NestedPools& pools); \
  \
  void GemmaAttention(size_t num_tokens, const size_t layer_idx, \
                      const LayerWeightsPtrs& layer, Activations& activations, \
                      QBatch& qbatch, MatMulEnv& env, int flags); \
                      const LayerWeightsPtrs& layer, \
                      AttentionActivations& activations, QBatch& qbatch, \
                      MatMulEnv& env, int flags); \
  /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
  }  // namespace NAMESPACE
@@ -65,14 +65,15 @@ void Attention(LayerAttentionType type, const size_t num_tokens,
               const size_t layer_idx, const LayerWeightsPtrs& layer,
               Activations& activations, QBatch& qbatch, MatMulEnv& env) {
  if (type == LayerAttentionType::kGemma) {
    GemmaAttention(num_tokens, layer_idx, layer, activations, qbatch, env,
    GemmaAttention(num_tokens, layer_idx, layer, activations.attention, qbatch,
                   env,
                   /*flags=*/0);
  } else {
    HWY_DASSERT(type == LayerAttentionType::kGriffinRecurrentBlock);
    // KVCache conv1d_cache and rglru_cache have one row per *Griffin* layer,
    // so map `layer` to the Griffin layer index.
    const size_t griffin_layer =
        activations.weights_config.NumLayersOfTypeBefore(type, layer_idx);
        activations.attention.config.NumLayersOfTypeBefore(type, layer_idx);
    GriffinRecurrent(num_tokens, griffin_layer, &layer, activations, qbatch,
                     env);
  }
@@ -86,15 +87,15 @@ static HWY_NOINLINE void TransformerLayer(const size_t num_tokens,
  const LayerConfig& layer_config = layer.layer_config;

  RMSNormBatched(activations.x, layer.pre_attention_norm_scale,
                 activations.pre_att_rms_out);
                 activations.attention.pre_att_rms_out);

  Attention(layer_config.type, num_tokens, layer_idx, layer, activations,
            qbatch, env);

  PostNorm(layer_config.post_norm, layer.post_attention_norm_scale,
           activations.att_sums);
           activations.attention.att_sums);

  ResidualConnection(activations.att_sums, activations.x, layer,
  ResidualConnection(activations.attention.att_sums, activations.x, layer,
                     /*is_attention=*/true);

  RMSNormBatched(activations.x, layer.pre_ffw_norm_scale,
@@ -470,7 +471,7 @@ static void GenerateT(const ModelConfig& config,
    HWY_ASSERT(qbatch.KV(qi).SeqLen() == seq_len);
  }
  HWY_ASSERT(prefill_tokens < seq_len);
  activations.div_seq_len = hwy::Divisor(static_cast<uint32_t>(seq_len));
  HWY_ASSERT(activations.attention.div_seq_len.GetDivisor() == seq_len);

  // Lacks a constructor to bulk-set, hence initialized by Prefill* which have
  // qi loops anyway.
@@ -60,20 +60,21 @@ void GriffinRecurrent(size_t num_tokens, size_t griffin_layer,

  const size_t num_interleaved = num_tokens * qbatch.Size();
  const hwy::Divisor div_qbatch(static_cast<uint32_t>(qbatch.Size()));
  GriffinActivations& griffin = *activations.griffin;

  // X / Y linear layers.
  // TODO: MatMul
  HWY_DASSERT(activations.griffin_y.Rows() == activations.griffin_x.Rows());
  HWY_DASSERT(num_interleaved == activations.griffin_y.Rows());
  HWY_DASSERT(griffin.griffin_y.Rows() == griffin.griffin_x.Rows());
  HWY_DASSERT(num_interleaved == griffin.griffin_y.Rows());
  CallUpcastedSame(
      &layer_weights->griffin.linear_x_w, &layer_weights->griffin.linear_y_w,
      [&](const auto* wx, const auto* wy) {
        for (size_t r = 0; r < num_interleaved; ++r) {
          float* HWY_RESTRICT y = activations.griffin_y.Row(r);
          float* HWY_RESTRICT x = activations.griffin_x.Row(r);
          float* HWY_RESTRICT y = griffin.griffin_y.Row(r);
          float* HWY_RESTRICT x = griffin.griffin_x.Row(r);
          TwoMatVecAdd(
              *wx, *wy, 0, model_dim, model_dim,
              activations.pre_att_rms_out.Row(r),
              activations.attention.pre_att_rms_out.Row(r),
              /*add0=*/layer_weights->griffin.linear_x_biases.PackedScale1(),
              /*add1=*/layer_weights->griffin.linear_y_biases.PackedScale1(),
              /*out0=*/x, /*out1=*/y, pool);
@@ -87,7 +88,7 @@ void GriffinRecurrent(size_t num_tokens, size_t griffin_layer,
    const size_t qi = div_qbatch.Remainder(interleaved_idx);
    const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
    const size_t pos = qbatch.Pos(qi) + batch_idx;
    float* HWY_RESTRICT x = activations.griffin_x.Row(qi);
    float* HWY_RESTRICT x = griffin.griffin_x.Row(qi);

    // cache[i] = input at time t-i.
    float* HWY_RESTRICT cache[kMaxConv1DWidth];
@@ -124,10 +125,10 @@ void GriffinRecurrent(size_t num_tokens, size_t griffin_layer,
    const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
    const size_t pos = qbatch.Pos(qi) + batch_idx;

    float* HWY_RESTRICT x = activations.griffin_x.Row(qi);
    float* HWY_RESTRICT y = activations.griffin_y.Row(qi);
    float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Row(qi);
    float* HWY_RESTRICT a = activations.griffin_multiplier.Row(qi);
    float* HWY_RESTRICT x = griffin.griffin_x.Row(qi);
    float* HWY_RESTRICT y = griffin.griffin_y.Row(qi);
    float* HWY_RESTRICT gate_x = griffin.griffin_gate_x.Row(qi);
    float* HWY_RESTRICT a = griffin.griffin_multiplier.Row(qi);
    float* HWY_RESTRICT rnn_state =
        qbatch.KV(qi).rglru_cache.Row(griffin_layer);
@@ -175,9 +176,9 @@ void GriffinRecurrent(size_t num_tokens, size_t griffin_layer,
  }  // interleaved_idx

  // Final linear layer.
  CallMatMul(activations.griffin_x, layer_weights->griffin.linear_out_w,
  CallMatMul(griffin.griffin_x, layer_weights->griffin.linear_out_w,
             layer_weights->griffin.linear_out_biases.PackedScale1(), env,
             activations.att_sums);
             activations.attention.att_sums);
}  // GriffinRecurrent

// NOLINTNEXTLINE(google-readability-namespace-comments)
gemma/vit.cc
@@ -56,10 +56,10 @@ class VitAttention {
  // Computes Q, K, V for all heads, stored in activations_.q.
  HWY_NOINLINE void ComputeQKV() {
    PROFILER_ZONE("Gen.VitAttention.QKV");
    auto& qkv = activations_.q;
    auto& qkv = activations_.attention.q;
    HWY_ASSERT(qkv.Rows() == num_tokens_);
    HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim);
    CallMatMul(activations_.pre_att_rms_out, layer_.vit.qkv_einsum_w,
    CallMatMul(activations_.attention.pre_att_rms_out, layer_.vit.qkv_einsum_w,
               layer_.vit.qkv_einsum_b.PackedScale1(), env_, qkv);
  }
@@ -69,7 +69,7 @@ class VitAttention {
    const size_t heads = layer_config_.heads;
    HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
    const size_t seq_len =
        static_cast<size_t>(activations_.div_seq_len.GetDivisor());
        static_cast<size_t>(activations_.attention.div_seq_len.GetDivisor());
    const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
    PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
@@ -82,12 +82,13 @@ class VitAttention {
                          MatPadding::kPacked);

    // Initialize att_out to zero prior to head loop.
    ZeroInit(activations_.att_out);
    ZeroInit(activations_.attention.att_out);

    for (size_t head = 0; head < heads; ++head) {
      pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
        const size_t token = task;
        float* HWY_RESTRICT q = activations_.q.Row(token) + head * 3 * qkv_dim;
        float* HWY_RESTRICT q =
            activations_.attention.q.Row(token) + head * 3 * qkv_dim;
        // TODO: shift to MatMul with A.scale once MatMul is confirmed working
        MulByConst(query_scale, q, qkv_dim);
        hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float));
@@ -95,8 +96,8 @@ class VitAttention {

      pool_.Run(0, seq_len, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
        const size_t seq_idx = task;
        float* HWY_RESTRICT k =
            activations_.q.Row(seq_idx) + head * 3 * qkv_dim + qkv_dim;
        float* HWY_RESTRICT k = activations_.attention.q.Row(seq_idx) +
                                head * 3 * qkv_dim + qkv_dim;
        hwy::CopyBytes(k, K.Row(seq_idx), qkv_dim * sizeof(float));
      });
@@ -111,10 +112,10 @@ class VitAttention {
      pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
        size_t token = task;
        float* HWY_RESTRICT att_out =
            activations_.att_out.Row(token) + head * qkv_dim;
            activations_.attention.att_out.Row(token) + head * qkv_dim;
        for (size_t i = 0; i < seq_len; ++i) {
          float* HWY_RESTRICT v =
              activations_.q.Row(i) + head * 3 * qkv_dim + 2 * qkv_dim;
          float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
                                  head * 3 * qkv_dim + 2 * qkv_dim;
          MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim);
        }
      });
@@ -126,7 +127,7 @@ class VitAttention {
    const size_t heads = layer_config_.heads;
    HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
    const size_t seq_len =
        static_cast<size_t>(activations_.div_seq_len.GetDivisor());
        static_cast<size_t>(activations_.attention.div_seq_len.GetDivisor());
    const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
    PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
@@ -137,24 +138,24 @@ class VitAttention {
      const size_t token = task / layer_config_.heads;
      // Compute Q.K scores, which are "logits" stored in head_att.
      float* HWY_RESTRICT q =
          activations_.q.Row(token) + head * 3 * qkv_dim;
          activations_.attention.q.Row(token) + head * 3 * qkv_dim;
      MulByConst(query_scale, q, qkv_dim);
      float* HWY_RESTRICT head_att =
          activations_.att.Row(token) + head * seq_len;
          activations_.attention.att.Row(token) + head * seq_len;
      for (size_t i = 0; i < seq_len; ++i) {
        float* HWY_RESTRICT k =
            activations_.q.Row(i) + head * 3 * qkv_dim + qkv_dim;
        float* HWY_RESTRICT k = activations_.attention.q.Row(i) +
                                head * 3 * qkv_dim + qkv_dim;
        head_att[i] = Dot(q, k, qkv_dim);  // score = q.k
      }
      // SoftMax yields "probabilities" in head_att.
      Softmax(head_att, seq_len);
      // Compute weighted sum of v into att_out.
      float* HWY_RESTRICT att_out =
          activations_.att_out.Row(token) + head * qkv_dim;
          activations_.attention.att_out.Row(token) + head * qkv_dim;
      hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
      for (size_t i = 0; i < seq_len; ++i) {
        float* HWY_RESTRICT v =
            activations_.q.Row(i) + head * 3 * qkv_dim + 2 * qkv_dim;
        float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
                                head * 3 * qkv_dim + 2 * qkv_dim;
        MulByConstAndAdd(head_att[i], v, att_out, qkv_dim);
      }
    });
@@ -168,8 +169,8 @@ class VitAttention {
    // att_weights and att_out are concatenated heads, each of length
    // qkv_dim. Thus the [num_tokens_, layer_config_.model_dim]
    // matmul output is the sum over heads.
    CallMatMul(activations_.att_out, layer_.vit.attn_out_w, bias, env_,
               activations_.att_sums);
    CallMatMul(activations_.attention.att_out, layer_.vit.attn_out_w, bias,
               env_, activations_.attention.att_sums);
  }

 public:
@@ -184,7 +185,7 @@ class VitAttention {

  HWY_INLINE void operator()() {
    ComputeQKV();
    if (activations_.weights_config.wrapping == PromptWrapping::GEMMA_VLM) {
    if (activations_.attention.config.wrapping == PromptWrapping::GEMMA_VLM) {
      DotSoftmaxWeightedSumMatrix();
    } else {
      DotSoftmaxWeightedSum();
@@ -233,7 +234,7 @@ void FFWVit(const LayerWeightsPtrs& layer, Activations& activations,
void VitTransformerLayer(size_t num_tokens, const size_t layer_idx,
                         const LayerWeightsPtrs& layer,
                         Activations& activations, MatMulEnv& env) {
  const size_t model_dim = activations.weights_config.model_dim;
  const size_t model_dim = activations.attention.config.model_dim;
  auto type = layer.layer_config.type;
  HWY_DASSERT(type == LayerAttentionType::kVit);
  (void)type;
@@ -246,14 +247,14 @@ void VitTransformerLayer(size_t num_tokens, const size_t layer_idx,
  // y = nn.LayerNorm()(x)
  // y ~ pre_att_rms_out
  LayerNormBatched(x, layer.vit.layer_norm_0_scale, layer.vit.layer_norm_0_bias,
                   activations.pre_att_rms_out);
                   activations.attention.pre_att_rms_out);

  // y = out["sa"] = nn.MultiHeadDotProductAttention(...)(y, y)
  // y ~ att_sums
  VitAttention(num_tokens, layer_idx, activations, layer, env)();

  // x = out["+sa"] = x + y
  AddFromBatched(activations.att_sums, x);
  AddFromBatched(activations.attention.att_sums, x);

  // y = nn.LayerNorm()(x)
  // y ~ pre_ffw_rms_out
@@ -32,6 +32,7 @@
#include <vector>

#include "gemma/activations.h"  // ChooseQueryScale
#include "gemma/configs.h"
#include "util/allocator.h"
#include "util/basics.h"  // BF16
#include "util/mat.h"  // MatStorageT
@@ -400,7 +401,7 @@ void TestRopeAndMulBy() {
    x.Row(0)[i] = random_float();
  }

  const float qmul = ChooseQueryScale(config);
  const float qmul = AttentionActivations::ChooseQueryScale(config);
  const float kmul = 1.0;

  MatStorageT<float> qexpected("qexpected", dim_qkv);