Fix Gemma 3 image handling: ensure the A matrix is packed, and preallocate the ViT attention matrices (vit_Q, vit_K, vit_C) in AttentionActivations.

Also ignore -2 tokens, which the tokenizer cannot decode.

PiperOrigin-RevId: 838869988
Jan Wassenberg 2025-12-01 11:46:47 -08:00 committed by Copybara-Service
parent 1564dd3111
commit a084d33e41
3 changed files with 35 additions and 10 deletions

@@ -66,6 +66,9 @@ struct AttentionActivations {
? batch_size * layer_config.heads * 3
: batch_size * layer_config.heads,
allocator)),
vit_Q(MatFactory("Q2", batch_size, layer_config.qkv_dim, allocator)),
vit_K(MatFactory("K2", seq_len, layer_config.qkv_dim, allocator)),
vit_C(MatFactory("C2", batch_size, seq_len, allocator)),
pre_att_rms_out(MatFactory("pre_att_rms_out", batch_size,
config.model_dim, allocator)),
att(MatFactory("att", batch_size, layer_config.heads * seq_len,
@@ -101,6 +104,7 @@ struct AttentionActivations {
q.AllocateAndAttachRowPtrs(row_ptrs);
q_bf.AllocateAndAttachRowPtrs(row_ptrs);
q_T.AllocateAndAttachRowPtrs(row_ptrs);
vit_C.AllocateAndAttachRowPtrs(row_ptrs);
att_sums.AllocateAndAttachRowPtrs(row_ptrs);
}
@@ -109,6 +113,10 @@ struct AttentionActivations {
q_bf.OverrideRows(batch_size);
// q_T rows are always qkv_dim!
vit_Q.OverrideRows(batch_size);
// vit_K stays seq_len!
vit_C.OverrideRows(batch_size);
pre_att_rms_out.OverrideRows(batch_size);
att.OverrideRows(batch_size);
att_out.OverrideRows(batch_size);
@@ -123,6 +131,10 @@ struct AttentionActivations {
MatStorageT<BF16> q_bf;
MatStorageT<BF16> q_T; // Transposed to maximize attention speed.
MatStorageT<float> vit_Q;
MatStorageT<float> vit_K;
MatStorageT<float> vit_C;
MatStorageT<float> pre_att_rms_out;
MatStorageT<float> att; // attention vector
MatStorageT<float> att_out; // attention output
@@ -150,6 +162,9 @@ struct AttentionActivationsPtrs {
q = activations.q;
q_bf = activations.q_bf;
q_T = activations.q_T;
vit_Q = activations.vit_Q;
vit_K = activations.vit_K;
vit_C = activations.vit_C;
pre_att_rms_out = activations.pre_att_rms_out;
att = activations.att;
att_out = activations.att_out;
@@ -164,6 +179,11 @@ struct AttentionActivationsPtrs {
q.OverrideRows(batch_size);
q_bf.OverrideRows(batch_size);
// q_T rows are always qkv_dim!
vit_Q.OverrideRows(batch_size);
// vit_K stays seq_len!
vit_C.OverrideRows(batch_size);
pre_att_rms_out.OverrideRows(batch_size);
att.OverrideRows(batch_size);
att_out.OverrideRows(batch_size);
@@ -184,6 +204,11 @@ struct AttentionActivationsPtrs {
MatPtrT<BF16> q_bf;
// Transposed query matrix for faster Q*K^T.
MatPtrT<BF16> q_T;
MatPtrT<float> vit_Q;
MatPtrT<float> vit_K;
MatPtrT<float> vit_C;
// Output of RMSNorm before attention, size batch_size x model_dim.
MatPtrT<float> pre_att_rms_out;
// Attention scores computed from Q*K^T, size batch_size x (q_heads *

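As a minimal illustration of the row-override behaviour added above (a sketch with stand-in types, not the real MatStorageT/MatPtrT classes): vit_Q and vit_C shrink to the rows of the current batch, while vit_K keeps seq_len rows because the keys cover every image position. All dimensions below are made up for the example.

#include <cassert>
#include <cstddef>

// Hypothetical stand-in for a row-major matrix with an overridable row count.
struct Mat {
  size_t rows, cols;
  // Adjusts the logical row count only; no reallocation happens here.
  void OverrideRows(size_t new_rows) { rows = new_rows; }
};

int main() {
  const size_t max_batch = 64, seq_len = 4096, qkv_dim = 256;  // illustrative
  Mat vit_Q{max_batch, qkv_dim};  // queries: one row per token in the batch
  Mat vit_K{seq_len, qkv_dim};    // keys: one row per image position
  Mat vit_C{max_batch, seq_len};  // Q*K^T scores: batch rows, seq_len columns

  const size_t batch_size = 16;   // tokens processed in the current call
  vit_Q.OverrideRows(batch_size);
  vit_C.OverrideRows(batch_size);
  // vit_K is deliberately not overridden ("vit_K stays seq_len!").
  assert(vit_K.rows == seq_len && vit_Q.rows == batch_size);
  return 0;
}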

@@ -129,11 +129,15 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
// callback function invoked for each generated token.
auto batch_stream_token = [&](size_t query_idx, size_t pos, int token,
float) {
std::string token_text;
HWY_ASSERT(gemma.Tokenizer().Decode(std::vector<int>{token}, &token_text));
HWY_ASSERT(pos == abs_pos);
++abs_pos;
std::string token_text;
if (!gemma.Tokenizer().Decode(std::vector<int>{token}, &token_text)) {
if (token == -2) return true; // Gemma 3 ViT?
HWY_WARN("Failed to decode token %d.", token);
}
const bool in_prompt = tokens_generated_this_turn < prompt_size;
const bool first_response_token = tokens_generated_this_turn == prompt_size;
++tokens_generated_this_turn;
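
A self-contained sketch of the callback change above, using a hypothetical DecodeToken() in place of gemma.Tokenizer().Decode(): tokens that fail to decode normally trigger a warning, but -2 (tentatively attributed to the Gemma 3 ViT in the comment above) is skipped silently and generation continues.

#include <cstdio>
#include <string>
#include <vector>

// Hypothetical stand-in for the tokenizer: fails for negative marker tokens,
// otherwise produces a placeholder piece of text.
static bool DecodeToken(int token, std::string* text) {
  if (token < 0) return false;
  *text = "<piece_" + std::to_string(token) + ">";
  return true;
}

int main() {
  const std::vector<int> stream = {5, -2, 7};  // illustrative token stream
  for (int token : stream) {
    std::string token_text;
    if (!DecodeToken(token, &token_text)) {
      if (token == -2) continue;  // internal marker token: ignore, keep going
      std::fprintf(stderr, "Failed to decode token %d.\n", token);
      continue;
    }
    std::fputs(token_text.c_str(), stdout);
  }
  std::fputc('\n', stdout);
  return 0;
}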

@@ -78,13 +78,9 @@ class VitAttention {
const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
// Shift Q, K, VT to MatStorageT.
MatStorageT<float> Q("Q2", Extents2D(num_tokens_, qkv_dim),
env_.ctx.allocator, MatPadding::kPacked);
MatStorageT<float> K("K2", Extents2D(seq_len, qkv_dim), env_.ctx.allocator,
MatPadding::kPacked);
MatStorageT<float> C("C2", Extents2D(num_tokens_, seq_len),
env_.ctx.allocator, MatPadding::kPacked);
MatPtrT<float>& Q = activations_.attention.vit_Q;
MatPtrT<float>& K = activations_.attention.vit_K;
MatPtrT<float>& C = activations_.attention.vit_C;
// Initialize att_out to zero prior to head loop.
ZeroInit(activations_.attention.att_out);
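
To make the intent of this change concrete, here is a sketch with stand-in types (not the real MatStorageT/MatPtrT or MatMul API): the three scratch matrices are allocated once with a packed, row-contiguous layout and owned by the activations, so the per-call attention body only takes references and no longer allocates. Names and sizes below are illustrative.

#include <cstddef>
#include <vector>

// Hypothetical packed row-major matrix: rows stored back to back with no
// padding, which is what a packed "A" operand of a matrix multiply requires.
struct PackedMat {
  size_t rows = 0, cols = 0;
  std::vector<float> data;
  void Allocate(size_t r, size_t c) { rows = r; cols = c; data.assign(r * c, 0.0f); }
  float* Row(size_t r) { return data.data() + r * cols; }
};

// Owned by the activations for the whole run, instead of per-call locals.
struct VitScratch {
  PackedMat vit_Q, vit_K, vit_C;
};

// Called once up front rather than on every attention invocation.
void Preallocate(VitScratch& s, size_t batch, size_t seq_len, size_t qkv_dim) {
  s.vit_Q.Allocate(batch, qkv_dim);    // queries
  s.vit_K.Allocate(seq_len, qkv_dim);  // keys
  s.vit_C.Allocate(batch, seq_len);    // Q*K^T scores
}

// The attention body now only borrows references; no allocation per call.
void DotSoftmaxSketch(VitScratch& s) {
  PackedMat& Q = s.vit_Q;
  PackedMat& K = s.vit_K;
  PackedMat& C = s.vit_C;
  // ... fill Q and K, compute C = Q * K^T, softmax over each row of C ...
  (void)Q; (void)K; (void)C;
}

int main() {
  VitScratch scratch;
  Preallocate(scratch, /*batch=*/16, /*seq_len=*/4096, /*qkv_dim=*/256);
  DotSoftmaxSketch(scratch);
  return 0;
}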