#include "models.h"

// This graph is used by llava, granite and glm.
// Because of embedding_stack (used by granite), we cannot reuse build_vit.
//
// Outline:
//   1. patch embeddings (+ optional class embedding) + learned position embeddings
//   2. optional pre-layernorm
//   3. transformer layers [0, max_feature_layer); the inputs of layers listed in
//      hparams.vision_feature_layer are collected into embedding_stack
//   4. optional post-layernorm; collected feature layers are concatenated along ne[0]
//   5. projector: llava-style (MLP / MLP_NORM / LDP / LDPV2) or GLM-Edge
//
// Returns the graph `gf` with the projected embeddings as its final node.
ggml_cgraph * clip_graph_llava::build() {
    const int batch_size = 1;
    const int n_pos      = n_patches + (model.class_embedding ? 1 : 0);

    GGML_ASSERT(n_patches_x == n_patches_y && "only square images supported");

    // Calculate the deepest feature layer based on hparams and projector type
    int max_feature_layer = n_layer;
    {
        // Get the index of the second to last layer; this is the default for models that have a llava projector
        int il_last = hparams.n_layer - 1;
        int deepest_feature_layer = -1;

        if (proj_type == PROJECTOR_TYPE_MINICPMV || proj_type == PROJECTOR_TYPE_GLM_EDGE) {
            il_last += 1;
        }

        // If we set explicit vision feature layers, only go up to the deepest one
        // NOTE: only used by granite-vision models for now
        for (const auto & feature_layer : hparams.vision_feature_layer) {
            if (feature_layer > deepest_feature_layer) {
                deepest_feature_layer = feature_layer;
            }
        }
        max_feature_layer = deepest_feature_layer < 0 ? il_last : deepest_feature_layer;
    }

    ggml_tensor * inp = build_inp();

    // concat class_embeddings and patch_embeddings
    if (model.class_embedding) {
        inp = ggml_concat(ctx0, inp, model.class_embedding, 1);
    }

    ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos);
    ggml_set_name(positions, "positions");
    ggml_set_input(positions);

    inp = ggml_add(ctx0, inp, ggml_get_rows(ctx0, model.position_embeddings, positions));

    ggml_tensor * inpL = inp;

    // pre-layernorm
    if (model.pre_ln_w) {
        inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, NORM_TYPE_NORMAL, eps, -1);
        cb(inpL, "pre_ln", -1);
    }

    // layer inputs selected by hparams.vision_feature_layer (granite-vision)
    std::vector<ggml_tensor *> embedding_stack;
    const auto & vision_feature_layer = hparams.vision_feature_layer;

    // loop over layers
    for (int il = 0; il < max_feature_layer; il++) {
        auto & layer = model.layers[il];
        ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states

        // If this is an embedding feature layer, save the output.
        // NOTE: 0 index here refers to the input to the encoder.
        if (vision_feature_layer.find(il) != vision_feature_layer.end()) {
            embedding_stack.push_back(cur);
        }

        // layernorm1
        cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, eps, il);
        cb(cur, "layer_inp_normed", il);

        // self-attention
        {
            ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.q_w, cur);
            if (layer.q_b) {
                Qcur = ggml_add(ctx0, Qcur, layer.q_b);
            }

            ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.k_w, cur);
            if (layer.k_b) {
                Kcur = ggml_add(ctx0, Kcur, layer.k_b);
            }

            ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.v_w, cur);
            if (layer.v_b) {
                Vcur = ggml_add(ctx0, Vcur, layer.v_b);
            }

            Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos);
            Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_pos);
            Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_pos);

            cb(Qcur, "Qcur", il);
            cb(Kcur, "Kcur", il);
            cb(Vcur, "Vcur", il);

            cur = build_attn(layer.o_w, layer.o_b,
                Qcur, Kcur, Vcur, nullptr, kq_scale, il);
            cb(cur, "attn_out", il);
        }

        // re-add the layer input, i.e., residual
        cur = ggml_add(ctx0, cur, inpL);

        inpL = cur; // inpL = residual, cur = hidden_states

        cb(cur, "ffn_inp", il);

        // layernorm2
        cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, eps, il);
        cb(cur, "ffn_inp_normed", il);

        // ffn
        cur = build_ffn(cur,
            layer.ff_up_w, layer.ff_up_b,
            layer.ff_gate_w, layer.ff_gate_b,
            layer.ff_down_w, layer.ff_down_b,
            hparams.ffn_op, il);

        cb(cur, "ffn_out", il);

        // residual 2
        cur = ggml_add(ctx0, inpL, cur);
        cb(cur, "layer_out", il);

        inpL = cur;
    }

    // post-layernorm
    if (model.post_ln_w) {
        inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, NORM_TYPE_NORMAL, eps, -1);
    }

    ggml_tensor * embeddings = inpL;

    // process vision feature layers (used by granite)
    {
        // final layer is a vision feature layer
        if (vision_feature_layer.find(max_feature_layer) != vision_feature_layer.end()) {
            embedding_stack.push_back(inpL);
        }

        // If feature layers are explicitly set, stack them (if we have multiple)
        if (!embedding_stack.empty()) {
            embeddings = embedding_stack[0];
            for (size_t i = 1; i < embedding_stack.size(); i++) {
                embeddings = ggml_concat(ctx0, embeddings, embedding_stack[i], 0);
            }
        }
    }

    // llava projector (also used by granite)
    if (hparams.has_llava_projector) {
        embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);

        ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
        ggml_set_name(patches, "patches");
        ggml_set_input(patches);

        // shape [1, 576, 1024]
        // ne is whcn, ne = [1024, 576, 1, 1]
        embeddings = ggml_get_rows(ctx0, embeddings, patches);

        // print_tensor_info(embeddings, "embeddings");

        // llava projector
        if (proj_type == PROJECTOR_TYPE_MLP) {
            embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
            embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);

            embeddings = ggml_gelu(ctx0, embeddings);
            if (model.mm_2_w) {
                embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
                embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
            }
        }
        else if (proj_type == PROJECTOR_TYPE_MLP_NORM) {
            embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
            embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
            // ggml_tensor_printf(embeddings, "mm_0_w",0,true,false);
            // First LayerNorm
            embeddings = ggml_norm(ctx0, embeddings, eps);
            embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_1_w),
                                model.mm_1_b);

            // GELU activation
            embeddings = ggml_gelu(ctx0, embeddings);

            // Second linear layer
            embeddings = ggml_mul_mat(ctx0, model.mm_3_w, embeddings);
            embeddings = ggml_add(ctx0, embeddings, model.mm_3_b);

            // Second LayerNorm
            embeddings = ggml_norm(ctx0, embeddings, eps);
            embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_4_w),
                                model.mm_4_b);
        }
        else if (proj_type == PROJECTOR_TYPE_LDP) {
            // MobileVLM projector
            // Side length of the square patch grid (squareness is asserted at the top of build()).
            // This was previously hard-coded to 24, which only matches a 24x24 = 576 patch grid;
            // any other grid size made ggml_cont_4d below fail its element-count check.
            // The concrete shapes in the comments below are for that 24x24, 2048-channel example.
            const int n_patch = n_patches_x;
            ggml_tensor * mlp_1 = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, embeddings);
            mlp_1 = ggml_add(ctx0, mlp_1, model.mm_model_mlp_1_b);
            mlp_1 = ggml_gelu(ctx0, mlp_1);
            ggml_tensor * mlp_3 = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, mlp_1);
            mlp_3 = ggml_add(ctx0, mlp_3, model.mm_model_mlp_3_b);
            // mlp_3 shape = [1, 576, 2048], ne = [2048, 576, 1, 1]

            // block 1
            ggml_tensor * block_1 = nullptr;
            {
                // transpose from [1, 576, 2048] --> [1, 2048, 576] --> [1, 2048, 24, 24]
                mlp_3 = ggml_permute(ctx0, mlp_3, 1, 0, 2, 3);
                mlp_3 = ggml_cont_4d(ctx0, mlp_3, n_patch, n_patch, mlp_3->ne[1], mlp_3->ne[2]);
                // stride = 1, padding = 1, bias is nullptr
                block_1 = ggml_conv_2d_dw(ctx0, model.mm_model_block_1_block_0_0_w, mlp_3, 1, 1, 1, 1, 1, 1);

                // layer norm
                // block_1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1]
                block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 2, 0, 3));
                // block_1 shape = [1, 24, 24, 2048], ne = [2048, 24, 24, 1]
                block_1 = ggml_norm(ctx0, block_1, eps);
                block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_1_block_0_1_w), model.mm_model_block_1_block_0_1_b);
                block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3));

                // block_1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1]
                // hardswish
                ggml_tensor * block_1_hw = ggml_hardswish(ctx0, block_1);

                // global average pooling: kernel and stride both span the full width/height
                block_1 = ggml_pool_2d(ctx0, block_1_hw, GGML_OP_POOL_AVG, block_1_hw->ne[0], block_1_hw->ne[1], block_1_hw->ne[0], block_1_hw->ne[1], 0, 0);
                // block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
                // pointwise conv
                block_1 = ggml_reshape_2d(ctx0, block_1, block_1->ne[0]*block_1->ne[1]*block_1->ne[2], block_1->ne[3]);
                block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_1_fc1_w, block_1);
                block_1 = ggml_add(ctx0, block_1, model.mm_model_block_1_block_1_fc1_b);
                block_1 = ggml_relu(ctx0, block_1);
                block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_1_fc2_w, block_1);
                block_1 = ggml_add(ctx0, block_1, model.mm_model_block_1_block_1_fc2_b);
                block_1 = ggml_hardsigmoid(ctx0, block_1);
                // block_1_hw shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1], block_1 shape = [1, 2048], ne = [2048, 1, 1, 1]
                block_1 = ggml_reshape_4d(ctx0, block_1, 1, 1, block_1->ne[0], block_1->ne[1]);
                block_1 = ggml_mul(ctx0, block_1_hw, block_1);

                int w = block_1->ne[0], h = block_1->ne[1];
                block_1 = ggml_reshape_3d(ctx0, block_1, w*h, block_1->ne[2], block_1->ne[3]);
                block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 0, 2, 3));

                // block_1 shape = [1, 24*24, 2048], ne = [24*24, 2048, 1]
                block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_2_0_w, block_1);
                block_1 = ggml_reshape_4d(ctx0, block_1, block_1->ne[0], w, h, block_1->ne[3]);

                // block_1 shape = [1, 24, 24, 2048], ne = [2048, 24, 24, 1]
                block_1 = ggml_norm(ctx0, block_1, eps);
                block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_1_block_2_1_w), model.mm_model_block_1_block_2_1_b);
                block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3));
                // block1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1]
                // residual
                block_1 = ggml_add(ctx0, mlp_3, block_1);
            }

            // block_2
            {
                // stride = 2
                block_1 = ggml_conv_2d_dw(ctx0, model.mm_model_block_2_block_0_0_w, block_1, 2, 2, 1, 1, 1, 1);

                // block_1 shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1]
                // layer norm
                block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 2, 0, 3));
                // block_1 shape = [1, 12, 12, 2048], ne = [2048, 12, 12, 1]
                block_1 = ggml_norm(ctx0, block_1, eps);
                block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_2_block_0_1_w), model.mm_model_block_2_block_0_1_b);
                block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3));
                // block_1 shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1]
                // hardswish
                ggml_tensor * block_1_hw = ggml_hardswish(ctx0, block_1);

                // global average pooling: kernel and stride both span the full width/height
                block_1 = ggml_pool_2d(ctx0, block_1_hw, GGML_OP_POOL_AVG, block_1_hw->ne[0], block_1_hw->ne[1], block_1_hw->ne[0], block_1_hw->ne[1], 0, 0);
                // block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
                // pointwise conv
                block_1 = ggml_reshape_2d(ctx0, block_1, block_1->ne[0]*block_1->ne[1]*block_1->ne[2], block_1->ne[3]);
                block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_1_fc1_w, block_1);
                block_1 = ggml_add(ctx0, block_1, model.mm_model_block_2_block_1_fc1_b);
                block_1 = ggml_relu(ctx0, block_1);
                block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_1_fc2_w, block_1);
                block_1 = ggml_add(ctx0, block_1, model.mm_model_block_2_block_1_fc2_b);
                block_1 = ggml_hardsigmoid(ctx0, block_1);

                // block_1_hw shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1], block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
                block_1 = ggml_reshape_4d(ctx0, block_1, 1, 1, block_1->ne[0], block_1->ne[1]);
                block_1 = ggml_mul(ctx0, block_1_hw, block_1);

                int w = block_1->ne[0], h = block_1->ne[1];
                block_1 = ggml_reshape_3d(ctx0, block_1, w*h, block_1->ne[2], block_1->ne[3]);
                block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 0, 2, 3));
                // block_1 shape = [1, 12*12, 2048], ne = [12*12, 2048, 1]
                block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_2_0_w, block_1);
                block_1 = ggml_reshape_4d(ctx0, block_1, block_1->ne[0], w, h, block_1->ne[3]);


                // block_1 shape = [1, 12, 12, 2048], ne = [2048, 12, 12, 1]
                block_1 = ggml_norm(ctx0, block_1, eps);
                block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_2_block_2_1_w), model.mm_model_block_2_block_2_1_b);
                block_1 = ggml_reshape_3d(ctx0, block_1, block_1->ne[0], block_1->ne[1] * block_1->ne[2], block_1->ne[3]);
                // block_1 shape = [1, 144, 2048], ne = [2048, 144, 1]
            }
            embeddings = block_1;
        }
        else if (proj_type == PROJECTOR_TYPE_LDPV2)
        {
            // Side length of the square patch grid (previously hard-coded to 24); see the LDP branch.
            // The concrete shapes in the comments below are for the 24x24, 2048-channel example.
            const int n_patch = n_patches_x;
            ggml_tensor * mlp_0 = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
            mlp_0 = ggml_add(ctx0, mlp_0, model.mm_model_mlp_0_b);
            mlp_0 = ggml_gelu(ctx0, mlp_0);
            ggml_tensor * mlp_2 = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, mlp_0);
            mlp_2 = ggml_add(ctx0, mlp_2, model.mm_model_mlp_2_b);
            // mlp_2 ne = [2048, 576, 1, 1]
            // // AVG Pool Layer 2*2, strides = 2
            mlp_2 = ggml_permute(ctx0, mlp_2, 1, 0, 2, 3);
            // mlp_2 ne = [576, 2048, 1, 1]
            mlp_2 = ggml_cont_4d(ctx0, mlp_2, n_patch, n_patch, mlp_2->ne[1], mlp_2->ne[2]);
            // mlp_2 ne [24, 24, 2048, 1]
            mlp_2 = ggml_pool_2d(ctx0, mlp_2, GGML_OP_POOL_AVG, 2, 2, 2, 2, 0, 0);
            // weight ne = [3, 3, 2048, 1]
            ggml_tensor * peg_0 = ggml_conv_2d_dw(ctx0, model.mm_model_peg_0_w, mlp_2, 1, 1, 1, 1, 1, 1);
            peg_0 = ggml_cont(ctx0, ggml_permute(ctx0, peg_0, 1, 2, 0, 3));
            peg_0 = ggml_add(ctx0, peg_0, model.mm_model_peg_0_b);
            mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 2, 0, 3));
            peg_0 = ggml_add(ctx0, peg_0, mlp_2);
            peg_0 = ggml_reshape_3d(ctx0, peg_0, peg_0->ne[0], peg_0->ne[1] * peg_0->ne[2], peg_0->ne[3]);
            embeddings = peg_0;
        }
        else {
            GGML_ABORT("fatal error");
        }
    }

    // glm projector
    else if (proj_type == PROJECTOR_TYPE_GLM_EDGE) {
        // side of the square token grid: ne[1] is the number of tokens here
        const size_t gridsz = static_cast<size_t>(sqrt(embeddings->ne[1]));
        embeddings = ggml_permute(ctx0,embeddings,1,0,2,3);
        embeddings = ggml_cont_3d(ctx0, embeddings, gridsz, gridsz, embeddings->ne[1]);
        embeddings = ggml_conv_2d(ctx0, model.mm_model_adapter_conv_w, embeddings, 2, 2, 0, 0, 1, 1);
        embeddings = ggml_reshape_3d(ctx0, embeddings,embeddings->ne[0]*embeddings->ne[1] , embeddings->ne[2], batch_size);
        embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings, 1, 0, 2, 3));
        embeddings = ggml_add(ctx0, embeddings, model.mm_model_adapter_conv_b);
        // GLU
        {
            embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
            embeddings = ggml_norm(ctx0, embeddings, eps);
            embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
            embeddings = ggml_gelu_inplace(ctx0, embeddings);
            ggml_tensor * x = embeddings;
            embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings);
            x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w,x);
            embeddings = ggml_swiglu_split(ctx0, embeddings, x);
            embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings);
        }
        // arrangement of BOI/EOI token embeddings
        // note: these embeddings are not present in text model, hence we cannot process them as text tokens
        // see: https://huggingface.co/THUDM/glm-edge-v-2b/blob/main/siglip.py#L53
        {
            embeddings = ggml_concat(ctx0, model.mm_boi, embeddings, 1); // BOI
            embeddings = ggml_concat(ctx0, embeddings, model.mm_eoi, 1); // EOI
        }
    }

    else {
        GGML_ABORT("llava: unknown projector type");
    }

    // build the graph
    ggml_build_forward_expand(gf, embeddings);

    return gf;
}