mirror of https://github.com/google/gemma.cpp.git
Fix Gemma3 image: ensure A matrix is packed, preallocate
Also ignore -2 tokens.

PiperOrigin-RevId: 838869988
parent 1564dd3111
commit a084d33e41
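Context for the "packed" part of the commit title: a matmul kernel that treats its A (left-hand) operand as packed expects every row to follow the previous one with no padding between rows. The sketch below only illustrates packed vs. padded row layout; the Mat struct and its members are hypothetical stand-ins, not gemma.cpp's MatStorageT/MatPadding API.

// Minimal sketch (not gemma.cpp's API): a "packed" matrix has stride == cols,
// so row r starts at data[r * cols]; a padded one rounds each row up to an
// alignment boundary, so a kernel assuming packed A would read wrong elements.
#include <cstddef>
#include <vector>

struct Mat {  // hypothetical helper, for illustration only
  Mat(size_t rows, size_t cols, size_t row_alignment = 1)
      : rows(rows), cols(cols),
        stride(((cols + row_alignment - 1) / row_alignment) * row_alignment),
        data(rows * stride) {}
  bool IsPacked() const { return stride == cols; }
  float* Row(size_t r) { return data.data() + r * stride; }
  size_t rows, cols, stride;
  std::vector<float> data;
};

int main() {
  Mat packed(4, 6);     // stride == 6: rows are contiguous
  Mat padded(4, 6, 8);  // stride == 8: two floats of padding per row
  // A matmul written against packed A may only be handed `packed`,
  // or it must be taught about `stride`.
  return packed.IsPacked() && !padded.IsPacked() ? 0 : 1;
}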
@@ -66,6 +66,9 @@ struct AttentionActivations {
                        ? batch_size * layer_config.heads * 3
                        : batch_size * layer_config.heads,
                    allocator)),
+        vit_Q(MatFactory("Q2", batch_size, layer_config.qkv_dim, allocator)),
+        vit_K(MatFactory("K2", seq_len, layer_config.qkv_dim, allocator)),
+        vit_C(MatFactory("C2", batch_size, seq_len, allocator)),
         pre_att_rms_out(MatFactory("pre_att_rms_out", batch_size,
                                    config.model_dim, allocator)),
         att(MatFactory("att", batch_size, layer_config.heads * seq_len,
@@ -101,6 +104,7 @@ struct AttentionActivations {
     q.AllocateAndAttachRowPtrs(row_ptrs);
     q_bf.AllocateAndAttachRowPtrs(row_ptrs);
     q_T.AllocateAndAttachRowPtrs(row_ptrs);
+    vit_C.AllocateAndAttachRowPtrs(row_ptrs);
     att_sums.AllocateAndAttachRowPtrs(row_ptrs);
   }
 
@@ -109,6 +113,10 @@ struct AttentionActivations {
     q_bf.OverrideRows(batch_size);
     // q_T rows are always qkv_dim!
+
+    vit_Q.OverrideRows(batch_size);
+    // vit_K stays seq_len!
+    vit_C.OverrideRows(batch_size);
 
     pre_att_rms_out.OverrideRows(batch_size);
     att.OverrideRows(batch_size);
     att_out.OverrideRows(batch_size);
@@ -123,6 +131,10 @@ struct AttentionActivations {
   MatStorageT<BF16> q_bf;
   MatStorageT<BF16> q_T;  // Transposed to maximize attention speed.
 
+  MatStorageT<float> vit_Q;
+  MatStorageT<float> vit_K;
+  MatStorageT<float> vit_C;
+
   MatStorageT<float> pre_att_rms_out;
   MatStorageT<float> att;      // attention vector
   MatStorageT<float> att_out;  // attention output
@@ -150,6 +162,9 @@ struct AttentionActivationsPtrs {
     q = activations.q;
     q_bf = activations.q_bf;
     q_T = activations.q_T;
+    vit_Q = activations.vit_Q;
+    vit_K = activations.vit_K;
+    vit_C = activations.vit_C;
     pre_att_rms_out = activations.pre_att_rms_out;
     att = activations.att;
     att_out = activations.att_out;
@@ -164,6 +179,11 @@ struct AttentionActivationsPtrs {
     q.OverrideRows(batch_size);
     q_bf.OverrideRows(batch_size);
     // q_T rows are always qkv_dim!
+
+    vit_Q.OverrideRows(batch_size);
+    // vit_K stays seq_len!
+    vit_C.OverrideRows(batch_size);
+
     pre_att_rms_out.OverrideRows(batch_size);
     att.OverrideRows(batch_size);
     att_out.OverrideRows(batch_size);
@@ -184,6 +204,11 @@ struct AttentionActivationsPtrs {
   MatPtrT<BF16> q_bf;
   // Transposed query matrix for faster Q*K^T.
   MatPtrT<BF16> q_T;
+
+  MatPtrT<float> vit_Q;
+  MatPtrT<float> vit_K;
+  MatPtrT<float> vit_C;
+
   // Output of RMSNorm before attention, size batch_size x model_dim.
   MatPtrT<float> pre_att_rms_out;
   // Attention scores computed from Q*K^T, size batch_size x (q_heads *
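The new vit_Q/vit_K/vit_C members are created once in the constructor, sized for the full batch or sequence length, and OverrideRows then only shrinks their logical row count per call (note the comments: vit_K stays at seq_len, q_T at qkv_dim). Below is a minimal sketch of that allocate-at-capacity, override-logical-rows pattern; RowBuffer is a hypothetical stand-in, not gemma.cpp's MatStorageT.

// Minimal sketch: allocate once at maximum size, then change only the logical
// row count per batch. No reallocation happens, so row pointers stay valid.
#include <cassert>
#include <cstddef>
#include <vector>

class RowBuffer {  // hypothetical
 public:
  RowBuffer(size_t max_rows, size_t cols)
      : max_rows_(max_rows), rows_(max_rows), cols_(cols),
        data_(max_rows * cols) {}
  void OverrideRows(size_t rows) {  // shrink the logical extent only
    assert(rows <= max_rows_);
    rows_ = rows;
  }
  size_t Rows() const { return rows_; }
  float* Row(size_t r) { return data_.data() + r * cols_; }

 private:
  size_t max_rows_, rows_, cols_;
  std::vector<float> data_;
};

int main() {
  RowBuffer vit_q(/*max batch*/ 512, /*qkv_dim*/ 256);
  vit_q.OverrideRows(64);  // this batch has 64 rows; storage is reused
  return vit_q.Rows() == 64 ? 0 : 1;
}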
gemma/run.cc (10 changed lines)
@@ -129,11 +129,15 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
   // callback function invoked for each generated token.
   auto batch_stream_token = [&](size_t query_idx, size_t pos, int token,
                                 float) {
-    std::string token_text;
-    HWY_ASSERT(gemma.Tokenizer().Decode(std::vector<int>{token}, &token_text));
-
     HWY_ASSERT(pos == abs_pos);
     ++abs_pos;
+
+    std::string token_text;
+    if (!gemma.Tokenizer().Decode(std::vector<int>{token}, &token_text)) {
+      if (token == -2) return true;  // Gemma 3 ViT?
+      HWY_WARN("Failed to decode token %d.", token);
+    }
+
     const bool in_prompt = tokens_generated_this_turn < prompt_size;
     const bool first_response_token = tokens_generated_this_turn == prompt_size;
     ++tokens_generated_this_turn;
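The run.cc change stops asserting when the tokenizer cannot decode a token: a -2 token (which the diff's comment attributes to the Gemma 3 ViT path) is ignored, and any other failure only logs a warning. A minimal sketch of that callback behavior follows; the Tokenizer here is a stand-in, not gemma.cpp's real interface.

// Minimal sketch of the new stream-token behavior: -2 is treated as an
// expected placeholder and skipped; other decode failures warn and continue.
#include <cstdio>
#include <string>

struct Tokenizer {  // stand-in for illustration
  bool Decode(int token, std::string* text) const {
    if (token < 0) return false;  // unknown / placeholder ids
    *text = "<tok" + std::to_string(token) + ">";
    return true;
  }
};

// Returns true to keep generating, mirroring the stream-token convention.
bool StreamToken(const Tokenizer& tok, int token) {
  std::string text;
  if (!tok.Decode(token, &text)) {
    if (token == -2) return true;  // image placeholder: silently ignore
    std::fprintf(stderr, "Failed to decode token %d.\n", token);
    // fall through with empty text, as the real callback keeps going
  }
  std::fputs(text.c_str(), stdout);
  return true;
}

int main() {
  Tokenizer tok;
  StreamToken(tok, 5);   // prints "<tok5>"
  StreamToken(tok, -2);  // ignored
  return 0;
}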
gemma/vit.cc (10 changed lines)
@@ -78,13 +78,9 @@ class VitAttention {
     const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
     PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
 
-    // Shift Q, K, VT to MatStorageT.
-    MatStorageT<float> Q("Q2", Extents2D(num_tokens_, qkv_dim),
-                         env_.ctx.allocator, MatPadding::kPacked);
-    MatStorageT<float> K("K2", Extents2D(seq_len, qkv_dim), env_.ctx.allocator,
-                         MatPadding::kPacked);
-    MatStorageT<float> C("C2", Extents2D(num_tokens_, seq_len),
-                         env_.ctx.allocator, MatPadding::kPacked);
+    MatPtrT<float>& Q = activations_.attention.vit_Q;
+    MatPtrT<float>& K = activations_.attention.vit_K;
+    MatPtrT<float>& C = activations_.attention.vit_C;
 
     // Initialize att_out to zero prior to head loop.
     ZeroInit(activations_.attention.att_out);