From a084d33e41cbdd14de897db1f616cc793f0c781f Mon Sep 17 00:00:00 2001
From: Jan Wassenberg
Date: Mon, 1 Dec 2025 11:46:47 -0800
Subject: [PATCH] Fix Gemma3 image: ensure A matrix is packed, preallocate

Also ignore -2 tokens

PiperOrigin-RevId: 838869988
---
 gemma/activations.h | 25 +++++++++++++++++++++++++
 gemma/run.cc        | 10 +++++++---
 gemma/vit.cc        | 10 +++-------
 3 files changed, 35 insertions(+), 10 deletions(-)

diff --git a/gemma/activations.h b/gemma/activations.h
index bebe902..a0627ae 100644
--- a/gemma/activations.h
+++ b/gemma/activations.h
@@ -66,6 +66,9 @@ struct AttentionActivations {
                       ? batch_size * layer_config.heads * 3
                       : batch_size * layer_config.heads,
                   allocator)),
+        vit_Q(MatFactory("Q2", batch_size, layer_config.qkv_dim, allocator)),
+        vit_K(MatFactory("K2", seq_len, layer_config.qkv_dim, allocator)),
+        vit_C(MatFactory("C2", batch_size, seq_len, allocator)),
         pre_att_rms_out(MatFactory("pre_att_rms_out", batch_size,
                                    config.model_dim, allocator)),
         att(MatFactory("att", batch_size, layer_config.heads * seq_len,
@@ -101,6 +104,7 @@ struct AttentionActivations {
     q.AllocateAndAttachRowPtrs(row_ptrs);
     q_bf.AllocateAndAttachRowPtrs(row_ptrs);
     q_T.AllocateAndAttachRowPtrs(row_ptrs);
+    vit_C.AllocateAndAttachRowPtrs(row_ptrs);
     att_sums.AllocateAndAttachRowPtrs(row_ptrs);
   }
 
@@ -109,6 +113,10 @@ struct AttentionActivations {
     q_bf.OverrideRows(batch_size);
     // q_T rows are always qkv_dim!
 
+    vit_Q.OverrideRows(batch_size);
+    // vit_K stays seq_len!
+    vit_C.OverrideRows(batch_size);
+
     pre_att_rms_out.OverrideRows(batch_size);
     att.OverrideRows(batch_size);
     att_out.OverrideRows(batch_size);
@@ -123,6 +131,10 @@ struct AttentionActivations {
   MatStorageT<BF16> q_bf;
   MatStorageT<BF16> q_T;  // Transposed to maximize attention speed.
 
+  MatStorageT<float> vit_Q;
+  MatStorageT<float> vit_K;
+  MatStorageT<float> vit_C;
+
   MatStorageT<float> pre_att_rms_out;
   MatStorageT<float> att;      // attention vector
   MatStorageT<float> att_out;  // attention output
@@ -150,6 +162,9 @@ struct AttentionActivationsPtrs {
     q = activations.q;
     q_bf = activations.q_bf;
     q_T = activations.q_T;
+    vit_Q = activations.vit_Q;
+    vit_K = activations.vit_K;
+    vit_C = activations.vit_C;
     pre_att_rms_out = activations.pre_att_rms_out;
     att = activations.att;
     att_out = activations.att_out;
@@ -164,6 +179,11 @@ struct AttentionActivationsPtrs {
     q.OverrideRows(batch_size);
     q_bf.OverrideRows(batch_size);
     // q_T rows are always qkv_dim!
+
+    vit_Q.OverrideRows(batch_size);
+    // vit_K stays seq_len!
+    vit_C.OverrideRows(batch_size);
+
     pre_att_rms_out.OverrideRows(batch_size);
     att.OverrideRows(batch_size);
     att_out.OverrideRows(batch_size);
@@ -184,6 +204,11 @@ struct AttentionActivationsPtrs {
   MatPtrT<BF16> q_bf;
   // Transposed query matrix for faster Q*K^T.
   MatPtrT<BF16> q_T;
+
+  MatPtrT<float> vit_Q;
+  MatPtrT<float> vit_K;
+  MatPtrT<float> vit_C;
+
   // Output of RMSNorm before attention, size batch_size x model_dim.
   MatPtrT<float> pre_att_rms_out;
   // Attention scores computed from Q*K^T, size batch_size x (q_heads *
diff --git a/gemma/run.cc b/gemma/run.cc
index 7e2059f..90da090 100644
--- a/gemma/run.cc
+++ b/gemma/run.cc
@@ -129,11 +129,15 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
 
   // callback function invoked for each generated token.
   auto batch_stream_token = [&](size_t query_idx, size_t pos, int token,
                                 float) {
-    std::string token_text;
-    HWY_ASSERT(gemma.Tokenizer().Decode(std::vector<int>{token}, &token_text));
-    HWY_ASSERT(pos == abs_pos);
     ++abs_pos;
+
+    std::string token_text;
+    if (!gemma.Tokenizer().Decode(std::vector<int>{token}, &token_text)) {
+      if (token == -2) return true;  // Gemma 3 ViT?
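+      // Note (assumption): -2 appears to be a placeholder token id emitted
+      // for image inputs by the Gemma 3 vision (ViT) path; it has no text
+      // form, so it is skipped above rather than warned about.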
+ HWY_WARN("Failed to decode token %d.", token); + } + const bool in_prompt = tokens_generated_this_turn < prompt_size; const bool first_response_token = tokens_generated_this_turn == prompt_size; ++tokens_generated_this_turn; diff --git a/gemma/vit.cc b/gemma/vit.cc index 1be3123..31c6f0f 100644 --- a/gemma/vit.cc +++ b/gemma/vit.cc @@ -78,13 +78,9 @@ class VitAttention { const float query_scale = 1.0f / sqrtf(static_cast(qkv_dim)); PROFILER_ZONE("Gen.VitAttention.DotSoftmax"); - // Shift Q, K, VT to MatStorageT. - MatStorageT Q("Q2", Extents2D(num_tokens_, qkv_dim), - env_.ctx.allocator, MatPadding::kPacked); - MatStorageT K("K2", Extents2D(seq_len, qkv_dim), env_.ctx.allocator, - MatPadding::kPacked); - MatStorageT C("C2", Extents2D(num_tokens_, seq_len), - env_.ctx.allocator, MatPadding::kPacked); + MatPtrT& Q = activations_.attention.vit_Q; + MatPtrT& K = activations_.attention.vit_K; + MatPtrT& C = activations_.attention.vit_C; // Initialize att_out to zero prior to head loop. ZeroInit(activations_.attention.att_out);