diff --git a/gemma/activations.h b/gemma/activations.h index 5c3e99b..f666eb0 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -38,8 +38,11 @@ struct Activations { is_griffin(config.model == Model::GRIFFIN_2B), x("x", Extents2D(batch_size, config.model_dim), pad_), - q("q", Extents2D(batch_size, layer_config.heads * layer_config.qkv_dim), - pad_), + // `vocab_size == 0` means it is for Vit part, VitAttention is still MHA + // and does not use an external KV cache. + q("q", Extents2D(batch_size, config.vocab_size == 0 ? + layer_config.heads * 3 * layer_config.qkv_dim : + layer_config.heads * layer_config.qkv_dim), pad_), logits("logits", Extents2D(batch_size, config.vocab_size), pad_), pre_att_rms_out("pre_att_rms_out",