refactor: use build_vit for VAETKI vision encoder

suhyun-hwang 2026-01-10 20:46:26 +09:00
parent 9d531ea9d5
commit d61a3f817c
1 changed file with 37 additions and 119 deletions


@@ -17,8 +17,6 @@ ggml_cgraph * clip_graph_vaetki::build() {
     inp = ggml_concat(ctx0, model.class_embedding, inp, 1);
     cb(inp, "inp_with_cls", -1);
 
-    ggml_tensor * inpL = inp;
-
     // position IDs for 2D RoPE (patch tokens only)
     ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
     ggml_set_name(positions, "positions");
@@ -28,138 +26,58 @@ ggml_cgraph * clip_graph_vaetki::build() {
     ggml_tensor * cls_cos = nullptr;
     ggml_tensor * cls_sin = nullptr;
     if (model.class_pos_emb) {
         // class_pos_emb: [head_dim/2] -> concat to [head_dim]
         ggml_tensor * cls_pos = ggml_concat(ctx0, model.class_pos_emb, model.class_pos_emb, 0);
         cls_cos = ggml_cos(ctx0, cls_pos);
         cls_sin = ggml_sin(ctx0, cls_pos);
     }
 
-    if (model.pre_ln_w) {
-        inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1);
-        cb(inpL, "pre_ln", -1);
-    }
-
-    for (int il = 0; il < n_layer; il++) {
-        const auto & layer = model.layers[il];
-        ggml_tensor * cur = inpL;
-
-        cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, norm_t, eps, il);
-        cb(cur, "ln1", il);
-
-        // self-attention with 2D RoPE
-        {
-            ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.q_w, cur);
-            if (layer.q_b) {
-                Qcur = ggml_add(ctx0, Qcur, layer.q_b);
-            }
-            ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.k_w, cur);
-            if (layer.k_b) {
-                Kcur = ggml_add(ctx0, Kcur, layer.k_b);
-            }
-            ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.v_w, cur);
-            if (layer.v_b) {
-                Vcur = ggml_add(ctx0, Vcur, layer.v_b);
-            }
-
-            Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos);
-            Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_pos);
-            Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_pos);
-
-            cb(Qcur, "Qcur", il);
-            cb(Kcur, "Kcur", il);
-            cb(Vcur, "Vcur", il);
-
-            // split CLS and patch tokens for RoPE
-            ggml_tensor * Q_cls = ggml_view_3d(ctx0, Qcur, d_head, n_head, 1,
-                    ggml_row_size(Qcur->type, d_head),
-                    ggml_row_size(Qcur->type, d_head * n_head), 0);
-            ggml_tensor * K_cls = ggml_view_3d(ctx0, Kcur, d_head, n_head, 1,
-                    ggml_row_size(Kcur->type, d_head),
-                    ggml_row_size(Kcur->type, d_head * n_head), 0);
-            ggml_tensor * Q_patch = ggml_view_3d(ctx0, Qcur, d_head, n_head, n_pos_patches,
-                    ggml_row_size(Qcur->type, d_head),
-                    ggml_row_size(Qcur->type, d_head * n_head),
-                    ggml_row_size(Qcur->type, d_head * n_head));
-            ggml_tensor * K_patch = ggml_view_3d(ctx0, Kcur, d_head, n_head, n_pos_patches,
-                    ggml_row_size(Kcur->type, d_head),
-                    ggml_row_size(Kcur->type, d_head * n_head),
-                    ggml_row_size(Kcur->type, d_head * n_head));
-
-            // apply RoPE to CLS token using class_pos_emb
-            if (cls_cos && cls_sin) {
-                // rotate_half: split into two halves, negate second, swap order
-                ggml_tensor * Q_cls_1 = ggml_view_3d(ctx0, Q_cls, d_head/2, n_head, 1,
-                        ggml_row_size(Q_cls->type, d_head),
-                        ggml_row_size(Q_cls->type, d_head * n_head), 0);
-                ggml_tensor * Q_cls_2 = ggml_view_3d(ctx0, Q_cls, d_head/2, n_head, 1,
-                        ggml_row_size(Q_cls->type, d_head),
-                        ggml_row_size(Q_cls->type, d_head * n_head),
-                        ggml_row_size(Q_cls->type, d_head/2));
-                ggml_tensor * Q_cls_rot = ggml_concat(ctx0, ggml_neg(ctx0, Q_cls_2), Q_cls_1, 0);
-                ggml_tensor * K_cls_1 = ggml_view_3d(ctx0, K_cls, d_head/2, n_head, 1,
-                        ggml_row_size(K_cls->type, d_head),
-                        ggml_row_size(K_cls->type, d_head * n_head), 0);
-                ggml_tensor * K_cls_2 = ggml_view_3d(ctx0, K_cls, d_head/2, n_head, 1,
-                        ggml_row_size(K_cls->type, d_head),
-                        ggml_row_size(K_cls->type, d_head * n_head),
-                        ggml_row_size(K_cls->type, d_head/2));
-                ggml_tensor * K_cls_rot = ggml_concat(ctx0, ggml_neg(ctx0, K_cls_2), K_cls_1, 0);
-
-                // RoPE: x * cos + rotate_half(x) * sin
-                Q_cls = ggml_add(ctx0,
-                        ggml_mul(ctx0, Q_cls, cls_cos),
-                        ggml_mul(ctx0, Q_cls_rot, cls_sin));
-                K_cls = ggml_add(ctx0,
-                        ggml_mul(ctx0, K_cls, cls_cos),
-                        ggml_mul(ctx0, K_cls_rot, cls_sin));
-            }
-
-            // apply 2D RoPE to patch tokens
-            Q_patch = ggml_rope_multi(ctx0, Q_patch, positions, nullptr,
-                    d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
-            K_patch = ggml_rope_multi(ctx0, K_patch, positions, nullptr,
-                    d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
-
-            Qcur = ggml_concat(ctx0, Q_cls, Q_patch, 2);
-            Kcur = ggml_concat(ctx0, K_cls, K_patch, 2);
-            cb(Qcur, "Qcur_rope", il);
-            cb(Kcur, "Kcur_rope", il);
-
-            cur = build_attn(layer.o_w, layer.o_b,
-                    Qcur, Kcur, Vcur, nullptr, kq_scale, il);
-            cb(cur, "attn_out", il);
-        }
-
-        cur = ggml_add(ctx0, cur, inpL);
-        inpL = cur;
-        cb(cur, "ffn_inp", il);
-
-        cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, norm_t, eps, il);
-        cb(cur, "ln2", il);
-
-        cur = build_ffn(cur,
-                layer.ff_up_w, layer.ff_up_b,
-                nullptr, nullptr,
-                layer.ff_down_w, layer.ff_down_b,
-                hparams.ffn_op, il);
-        cb(cur, "ffn_out", il);
-
-        cur = ggml_add(ctx0, inpL, cur);
-        cb(cur, "layer_out", il);
-        inpL = cur;
-    }
+    auto add_pos = [&](ggml_tensor * cur, const clip_layer &) -> ggml_tensor * {
+        // split CLS and patch tokens
+        ggml_tensor * cur_cls = ggml_view_3d(ctx0, cur, d_head, n_head, 1,
+                ggml_row_size(cur->type, d_head),
+                ggml_row_size(cur->type, d_head * n_head), 0);
+        ggml_tensor * cur_patch = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos_patches,
+                ggml_row_size(cur->type, d_head),
+                ggml_row_size(cur->type, d_head * n_head),
+                ggml_row_size(cur->type, d_head * n_head));
+
+        // apply RoPE to CLS token using class_pos_emb
+        if (cls_cos && cls_sin) {
+            ggml_tensor * cls_1 = ggml_view_3d(ctx0, cur_cls, d_head/2, n_head, 1,
+                    ggml_row_size(cur_cls->type, d_head),
+                    ggml_row_size(cur_cls->type, d_head * n_head), 0);
+            ggml_tensor * cls_2 = ggml_view_3d(ctx0, cur_cls, d_head/2, n_head, 1,
+                    ggml_row_size(cur_cls->type, d_head),
+                    ggml_row_size(cur_cls->type, d_head * n_head),
+                    ggml_row_size(cur_cls->type, d_head/2));
+            ggml_tensor * cls_rot = ggml_concat(ctx0, ggml_neg(ctx0, cls_2), cls_1, 0);
+            cur_cls = ggml_add(ctx0,
+                    ggml_mul(ctx0, cur_cls, cls_cos),
+                    ggml_mul(ctx0, cls_rot, cls_sin));
+        }
+
+        // apply 2D RoPE to patch tokens
+        cur_patch = ggml_rope_multi(ctx0, cur_patch, positions, nullptr,
+                d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
+
+        return ggml_concat(ctx0, cur_cls, cur_patch, 2);
+    };
+
+    ggml_tensor * cur = build_vit(
+            inp, n_pos,
+            norm_t,
+            hparams.ffn_op,
+            nullptr,
+            add_pos);
+    cb(cur, "vit_out", -1);
 
     // remove CLS token
-    ggml_tensor * embeddings = ggml_view_2d(ctx0, inpL,
+    ggml_tensor * embeddings = ggml_view_2d(ctx0, cur,
             n_embd, n_pos_patches,
-            ggml_row_size(inpL->type, n_embd),
-            ggml_row_size(inpL->type, n_embd));
+            ggml_row_size(cur->type, n_embd),
+            ggml_row_size(cur->type, n_embd));
     cb(embeddings, "patches_only", -1);
 
     // merger
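
The new add_pos lambda plugs into the shared build_vit helper, which now owns the pre-LN, the per-layer attention/FFN loop and the residual bookkeeping that the deleted code did by hand. Roughly, the callback is invoked on Q and K in every layer once they are reshaped to [d_head, n_head, n_pos]; a minimal sketch of that hand-off (an illustration, not the exact upstream helper code):

    // inside build_vit's layer loop, after the Q/K/V projections (sketch)
    Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos);
    Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_pos);
    if (add_pos) {
        // for VAETKI this is the lambda above: rotate the CLS token with
        // class_pos_emb and apply 2D RoPE to the patch tokens
        Qcur = add_pos(Qcur, layer);
        Kcur = add_pos(Kcur, layer);
    }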