diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index ded8721199..a50006ca9d 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -659,237 +659,44 @@ struct clip_graph {
         return gf;
     }
 
-    ggml_tensor * build_sam_enc(ggml_tensor * inp_raw) {
-        constexpr int enc_n_embd = 768;
-        constexpr int _depth = 12;
-        constexpr int enc_n_heads = 12;
-        constexpr int enc_d_heads = enc_n_embd / enc_n_heads;
-
-        ggml_tensor * inpL;
-
-        inpL = ggml_conv_2d_sk_p0(ctx0, model.patch_embed_proj_w, inp_raw);
-        inpL = ggml_add(ctx0, inpL, ggml_reshape_3d(ctx0, model.patch_embed_proj_b, 1, 1, enc_n_embd));
-        inpL = ggml_cont(ctx0, ggml_permute(ctx0, inpL, 1, 2, 0, 3));
-
-        ggml_tensor * cur;
-        const auto tgt_size = inpL->ne[1];
-        const auto str_size = model.pos_embed->ne[1];
-        if (str_size != tgt_size) {
-            ggml_tensor * old_pos_embed = nullptr;
-            old_pos_embed = ggml_cont(ctx0, ggml_permute(ctx0, model.pos_embed, 2, 0, 1, 3));
-            ggml_tensor * new_pos_embed = ggml_interpolate(
-                ctx0,
-                old_pos_embed,
-                tgt_size,
-                tgt_size,
-                enc_n_embd,
-                1,
-                ggml_scale_mode::GGML_SCALE_MODE_BICUBIC
-            );
-            new_pos_embed = ggml_cont(ctx0, ggml_permute(ctx0, new_pos_embed, 1, 2, 0, 3));
-            cur = ggml_add(ctx0, inpL, new_pos_embed);
-        } else {
-            cur = ggml_add(ctx0, inpL, model.pos_embed);
-        }
-
-        // loop over layers
-        for (int il = 0; il < _depth; il++) {
-            auto & layer = model.sam_layers[il];
-            ggml_tensor * shortcut = cur;
-
-            // layernorm1
-            cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, eps, il);
-
-            const int64_t w0 = cur->ne[1];
-            const int64_t h0 = cur->ne[2];
-
-            if (hparams.is_global_attn(il) == false) {
-                // local attention layer - apply window partition
-                // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L169-L172
-                //cur = ggml_win_part(ctx0, cur, 14);
-                cur = window_partition(ctx0, cur, 14); // TODO: make this configurable
-            }
-
-            const int64_t W = cur->ne[1];
-            const int64_t H = cur->ne[2];
-
-            // self-attention
-            {
-                const int B = cur->ne[3];
-
-                cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
-                cur = ggml_add(ctx0, cur, layer.qkv_b);
-                cur = ggml_cont(ctx0, cur); // Ensure tensor is contiguous before reshape
-                cur = ggml_reshape_4d(ctx0, cur, enc_n_embd, 3, W*H, B);
-
-                ggml_tensor * Q;
-                ggml_tensor * K;
-                ggml_tensor * V;
-
-                Q = ggml_view_3d (ctx0, cur, enc_n_embd, W*H, B, cur->nb[2], cur->nb[3], 0*cur->nb[1]);
-                Q = ggml_reshape_4d(ctx0, ggml_cont(ctx0, Q), enc_d_heads, enc_n_heads, W*H, B);
-                Q = ggml_cont (ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); // [B, enc_n_heads, H*W, enc_d_heads]
-
-                K = ggml_view_3d (ctx0, cur, enc_n_embd, W*H, B, cur->nb[2], cur->nb[3], 1*cur->nb[1]);
-                K = ggml_reshape_4d(ctx0, ggml_cont(ctx0, K), enc_d_heads, enc_n_heads, W*H, B);
-                K = ggml_cont (ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); // [B, enc_n_heads, H*W, enc_d_heads]
-
-                V = ggml_view_3d (ctx0, cur, enc_n_embd, W*H, B, cur->nb[2], cur->nb[3], 2*cur->nb[1]);
-                V = ggml_reshape_4d(ctx0, ggml_cont(ctx0, V), enc_d_heads, enc_n_heads, W*H, B);
-                V = ggml_cont (ctx0, ggml_permute(ctx0, V, 0, 2, 1, 3)); // [B, enc_n_heads, H*W, enc_d_heads]
-
-                ggml_tensor * mask;
-                ggml_tensor * rw;
-                ggml_tensor * rh;
-                ggml_tensor * qr;
-
-                rw = get_rel_pos(ctx0, layer.rel_pos_w, W, W); // [W, W, C]
-                rh = get_rel_pos(ctx0, layer.rel_pos_h, H, H); // [H, H, C]
-                qr = ggml_reshape_4d(ctx0, Q, enc_d_heads, W, H, B*enc_n_heads);
-
-                const int WH_pad = GGML_PAD(W*H, GGML_KQ_MASK_PAD) - W*H;
-
-                rw = ggml_mul_mat (ctx0, rw, ggml_cont(ctx0, ggml_permute(ctx0, qr, 0, 2, 1, 3))); // [B*enc_n_heads, W, H, W]
-                rw = ggml_cont (ctx0, ggml_permute(ctx0, rw, 0, 2, 1, 3)); // [B*enc_n_heads, H, W, W]
-                rw = ggml_reshape_4d(ctx0, rw, W, 1, W*H, enc_n_heads*B);
-                rw = ggml_repeat_4d (ctx0, rw, W, H, W*H, enc_n_heads*B);
-                rh = ggml_mul_mat (ctx0, rh, qr); // [B*enc_n_heads, H, W, H]
-                rh = ggml_reshape_4d(ctx0, rh, 1, H, W*H, enc_n_heads*B);
-                mask = ggml_add (ctx0, rw, rh); // [B*enc_n_heads, H*W, H, W]
-                mask = ggml_reshape_4d(ctx0, mask, W*H, W*H, enc_n_heads, B);
-                mask = ggml_pad (ctx0, mask, 0, WH_pad, 0, 0);
-                mask = ggml_cast (ctx0, mask, GGML_TYPE_F16);
-
-                float scale = 1.0f / sqrtf((float)enc_d_heads);
-                cur = ggml_flash_attn_ext(ctx0, Q, K, V, mask, scale, 0.0f, 0.0f); // [B, H*W, enc_n_heads, enc_d_heads]
-
-                cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur), enc_n_embd, W, H, B);
-                cur = ggml_mul_mat(ctx0, layer.o_w, cur);
-                cur = ggml_add_inplace(ctx0, cur, layer.o_b);
-            }
-
-            if (hparams.is_global_attn(il) == false) {
-                // local attention layer - reverse window partition
-                cur = window_unpartition(ctx0, cur, w0, h0, 14); // TODO: make window size configurable
-            }
-
-            // re-add the layer input, e.g., residual
-            cur = ggml_add(ctx0, cur, shortcut);
-
-            ggml_tensor * inpFF = cur;
-
-            // layernorm2
-            cur = build_norm(inpFF, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, eps, il);
-
-            // ffn
-            cur = build_ffn(cur, layer.ff_up_w, layer.ff_up_b, nullptr, nullptr, layer.ff_down_w,
-                            layer.ff_down_b, hparams.ffn_op, il);
-
-            // residual 2
-            cur = ggml_add(ctx0, cur, inpFF);
-            cb(cur, "sam_layer_out", il);
-        }
-
-        cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3));
-
-        const int out_chans = model.neck_0_w->ne[3];
-
-        cur = ggml_conv_2d(ctx0, model.neck_0_w, cur, 1, 1, 0, 0, 1, 1);
-        cur = sam_layer_norm_2d(ctx0, cur, out_chans, model.neck_1_w, model.neck_1_b, hparams.eps);
-        cur = ggml_conv_2d(ctx0, model.neck_2_w, cur, 1, 1, 1, 1, 1, 1);
-        cur = sam_layer_norm_2d(ctx0, cur, out_chans, model.neck_3_w, model.neck_3_b, hparams.eps);
-
-        cur = ggml_conv_2d(ctx0, model.net_2, cur, 2, 2, 1, 1, 1, 1);
-        cur = ggml_conv_2d(ctx0, model.net_3, cur, 2, 2, 1, 1, 1, 1);
-        cb(cur, "sam_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-        return cur;
-    }
-
-    ggml_tensor * sam_layer_norm_2d(ggml_context * ctx0,
-                                    ggml_tensor * layer,
-                                    int n_channels,
-                                    ggml_tensor * w,
-                                    ggml_tensor * b,
-                                    float eps) {
-        // LayerNorm2d
-        // normalize along channel dimmension
-        // TODO: better implementation
-        layer = ggml_permute(ctx0, ggml_norm(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, layer, 1, 2, 0, 3)), eps), 2, 0,
-                             1, 3);
-        layer = ggml_cont(ctx0, layer);
-
-        layer =
-            ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, ggml_reshape_3d(ctx0, w, 1, 1, n_channels), layer), layer),
-                     ggml_repeat(ctx0, ggml_reshape_3d(ctx0, b, 1, 1, n_channels), layer));
-
-        return layer;
-    }
-
     ggml_cgraph * build_deepseek_ocr() {
         //patch embedding
         ggml_tensor * inp_raw = build_inp_raw();
 
-        ggml_tensor * global_features_1 = build_sam_enc(inp_raw);
-        ggml_tensor * global_features_2 = build_dp_ocr_clip(global_features_1);
+        ggml_tensor * sam_out = build_sam(inp_raw);
+        ggml_tensor * clip_out = build_dsocr_clip(sam_out);
 
-        // FIXME remove n_patches is hardcoded
+        int clip_n_patches = sam_out->ne[0] * sam_out->ne[1];
 
-        // torch global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
-        global_features_1 = ggml_cont(ctx0,ggml_permute(ctx0, global_features_1, 1, 2, 0, 3));
-        int clip_n_patches = global_features_1->ne[1] * global_features_1->ne[2];
-
-        // flatten 2nd and 3rd dims
-        global_features_1 = ggml_reshape_2d(ctx0, global_features_1, global_features_1->ne[0], clip_n_patches);
-
-        // remove CLS token
-        global_features_2 = ggml_view_2d(ctx0, global_features_2, n_embd, clip_n_patches,
-                                         global_features_2->nb[1], global_features_2->nb[1]);
+        sam_out = ggml_cont(ctx0, ggml_permute(ctx0, sam_out, 1, 2, 0, 3));
+        sam_out = ggml_reshape_2d(ctx0, sam_out, sam_out->ne[0], clip_n_patches);
+        clip_out = ggml_view_2d(ctx0, clip_out, n_embd, clip_n_patches, clip_out->nb[1], clip_out->nb[1]);
 
-        ggml_tensor * global_features = ggml_concat(ctx0, global_features_2, global_features_1, 0);
-        global_features = ggml_reshape_2d(ctx0, global_features, 2* n_embd,clip_n_patches);
-        global_features = ggml_cont(ctx0, global_features);
-        global_features = ggml_mul_mat(ctx0, model.fc_w, global_features);
-        global_features = ggml_add(ctx0, global_features, model.fc_b);
-
-        global_features = build_global_local_features(ctx0,global_features);
-
-        cb(global_features, "dsocr_output", -1);
-
-        ggml_build_forward_expand(gf, global_features);
-        return gf;
-    }
-
-    // global_features: [n_dim, h*w]
-    // image_newline: [n_dim]
-    // view_separator: [n_dim]
-
-    ggml_tensor * build_global_local_features(ggml_context * ctx0,
-                                              ggml_tensor * global_features) {
-        GGML_ASSERT(model.image_newline != nullptr);
-        GGML_ASSERT(model.view_seperator != nullptr);
-
-        const auto h = static_cast<int>(std::sqrt(static_cast<float>(global_features->ne[1])));
-        const auto w = h;
-        const auto n_dim = global_features->ne[0];
         ggml_tensor * cur;
+        cur = ggml_concat(ctx0, clip_out, sam_out, 0);
+        cur = ggml_reshape_2d(ctx0, cur, 2*n_embd, clip_n_patches);
+        cur = ggml_cont(ctx0, cur);
+        cur = ggml_mul_mat(ctx0, model.fc_w, cur);
+        cur = ggml_add(ctx0, cur, model.fc_b);
+
+        const auto h = static_cast<int>(std::sqrt(static_cast<float>(cur->ne[1])));
+        const auto w = h;
+        const auto n_dim = cur->ne[0];
+
         ggml_tensor * imgnl;
         ggml_tensor * vs;
 
-        cur = ggml_reshape_3d(ctx0, global_features, n_dim, w, h);
         imgnl = ggml_repeat_4d(ctx0, model.image_newline, n_dim, 1, h, 1);
-        cur = ggml_reshape_2d(ctx0, ggml_concat(ctx0, cur, imgnl, 1), n_dim, (w+1)*h);
-        cb(cur, "insert_imgnl", -1);
         vs = ggml_reshape_2d(ctx0, model.view_seperator, n_dim, 1); // (n_dim, 1)
+        cur = ggml_reshape_3d(ctx0, cur, n_dim, w, h);
+        cur = ggml_reshape_2d(ctx0, ggml_concat(ctx0, cur, imgnl, 1), n_dim, (w+1)*h);
         cur = ggml_concat(ctx0, cur, vs, 1); // (n_dim, h*(w+1) + 1)
-        cb(cur, "insert_vs", -1);
-        return cur;
+        cb(cur, "dsocr_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+        return gf;
     }
-
-
 
     ggml_cgraph * build_pixtral() {
         const int n_merge = hparams.n_merge;
@@ -1541,62 +1348,6 @@ struct clip_graph {
         return gf;
     }
 
-    ggml_tensor * build_dp_ocr_clip(ggml_tensor * patch_embeds) {
-        GGML_ASSERT(model.class_embedding != nullptr);
-        GGML_ASSERT(model.position_embeddings != nullptr);
-
-        ggml_tensor * inp = ggml_cpy(ctx0, patch_embeds, ggml_dup_tensor(ctx0, patch_embeds));
-
-
-        inp = ggml_reshape_2d(ctx0, inp, inp->ne[0]*inp->ne[1], inp->ne[2]);
-        inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3));
-
-        ggml_tensor * new_pos_embd = ggml_cpy(ctx0, model.position_embeddings, ggml_dup_tensor(ctx0, model.position_embeddings));
-
-        int n_pos = new_pos_embd->ne[1]; // +1 for [CLS]
-        const auto tgt_size = static_cast<int>(std::sqrt(inp->ne[1]));
-        const auto src_size = static_cast<int>(std::sqrt(n_pos - 1));
-
-
-        if (tgt_size != src_size) {
-            //ggml_tensor * old_pos_embd = ggml_new_tensor_2d(ctx0, model.position_embeddings->type, model.position_embeddings->ne[0], str_size * str_size);
-            ggml_tensor * old_pos_embd = ggml_view_2d(ctx0, new_pos_embd,
-                                                      new_pos_embd->ne[0], src_size * src_size,
-                                                      ggml_row_size(new_pos_embd->type, new_pos_embd->ne[0]), 0);
-            ggml_tensor * cls_tok = ggml_view_2d(ctx0, new_pos_embd,
-                                                 new_pos_embd->ne[0], 1,
-                                                 ggml_row_size(new_pos_embd->type, new_pos_embd->ne[0]), src_size * src_size);
-            new_pos_embd = ggml_interpolate(ctx0,
-                                            old_pos_embd,
-                                            tgt_size,
-                                            tgt_size,
-                                            new_pos_embd->ne[0], 1, GGML_SCALE_MODE_BICUBIC);
-            new_pos_embd = ggml_reshape_3d(ctx0, new_pos_embd, n_embd, tgt_size * tgt_size, 1);
-            //new_pos_embd = ggml_cont(ctx0, ggml_permute(ctx0, new_pos_embd, 2,1,0,3));
-            new_pos_embd = ggml_concat(ctx0, new_pos_embd, cls_tok, 1);
-            n_pos = tgt_size * tgt_size + 1;
-        }
-
-
-
-        // add CLS token
-        inp = ggml_concat(ctx0, model.class_embedding, inp, 1);
-
-        //TODO : check norm type for dp-ocr-clip
-        norm_type norm_t = NORM_TYPE_NORMAL;
-
-        // for selecting learned pos embd, used by ViT
-        ggml_tensor * positions = ggml_cast(ctx0, ggml_arange(ctx0, 0, n_pos, 1), GGML_TYPE_I32);
-        ggml_tensor * learned_pos_embd = ggml_get_rows(ctx0, new_pos_embd, positions);
-
-        ggml_tensor * cur = build_vit(inp, n_pos, norm_t, ffn_op_type::FFN_GELU_QUICK,
-                                      learned_pos_embd, nullptr); // shape [1024, 16, 16]
-
-        ggml_build_forward_expand(gf, cur);
-
-        return cur;
-    }
-
     ggml_cgraph * build_llama4() {
         GGML_ASSERT(model.class_embedding != nullptr);
         GGML_ASSERT(model.position_embeddings != nullptr);
@@ -2500,44 +2251,6 @@ private:
         return inpL;
     }
 
-    // attn: [q_h*q_w, k_h*k_w]
-    // rel_h: [q_h, q_w, k_h]
-    // rel_w: [q_h, q_w, k_w]
-
-    static ggml_tensor * add_rel_pos_inplace(
-            ggml_context * ctx,
-            ggml_tensor * attn,
-            ggml_tensor * rel_w,
-            ggml_tensor * rel_h
-    ) {
-        const int k_w = rel_w->ne[0];
-        const int k_h = rel_h->ne[0];
-        const int q_w = rel_h->ne[1];
-        const int q_h = rel_h->ne[2];
-
-        GGML_ASSERT(q_w == rel_w->ne[1]);
-        GGML_ASSERT(q_h == rel_w->ne[2]);
-        GGML_ASSERT(attn->ne[0] == k_h*k_w);
-        GGML_ASSERT(attn->ne[1] == q_h*q_w);
-
-        ggml_tensor *attn_4d = ggml_reshape_4d(ctx, attn, k_w, k_h, attn->ne[1], attn->ne[2]);
-
-        ggml_tensor *rel_h_4d = ggml_reshape_4d(ctx, rel_h, 1, k_h, attn->ne[1], attn->ne[2]);
-
-        ggml_tensor *rel_h_rep = ggml_repeat(ctx, rel_h_4d, attn_4d); // now same shape as attn_5d
-
-        ggml_tensor *rel_w_4d = ggml_reshape_4d(ctx, rel_w, k_w, 1, attn->ne[1], attn->ne[2]);
-
-        ggml_tensor *rel_w_rep = ggml_repeat(ctx, rel_w_4d, attn_4d); // now same shape as attn_5d
-
-        ggml_tensor * result = ggml_add_inplace(ctx, attn_4d, ggml_add_inplace(ctx, rel_h_rep, rel_w_rep));
-        result = ggml_reshape_3d(ctx, result, attn->ne[0], attn->ne[1], attn->ne[2]);
-
-
-        return result;
-    }
-
-
     static ggml_tensor * get_rel_pos(
             ggml_context * ctx,
             ggml_tensor * rel_pos, // [L, C]
@@ -2683,28 +2396,6 @@ private:
         return x;
     }
 
-    // build the input after conv2d (inp_raw --> patches)
-    // returns tensor with shape [n_embd, n_patches]
-    ggml_tensor * build_enc_inp(ggml_tensor * inp_raw,
-                                const int enc_patch_size,
-                                const int enc_n_patches,
-                                const int enc_n_embd) {
-        GGML_ASSERT(model.patch_embed_proj_w != nullptr);
-        GGML_ASSERT(model.patch_embed_proj_b != nullptr);
-        // Image to Patch Embedding.
-        // ggml_tensor * inp_raw = build_inp_raw(); // sam shape = [1024, 1024, 3]
-        // patch_embed_proj_w shape = [768, 3, 16, 16]
-        ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embed_proj_w, inp_raw, enc_patch_size, enc_patch_size, 0, 0,
-                                         1, 1); // [64, 64, 768]
-        inp = ggml_reshape_2d(ctx0, inp, enc_n_patches * enc_n_patches, enc_n_embd); // [4096, 768]
-        inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); // [768, 4096]
-        inp = ggml_add(ctx0, inp, model.patch_embed_proj_b);
-        inp = ggml_cont(ctx0, inp);
-        inp = ggml_reshape_4d(ctx0, inp, enc_n_embd, enc_n_patches, enc_n_patches, 1);
-        cb(inp, "enc_patch_bias", -1);
-        return inp;
-    }
-
     // build the input after conv2d (inp_raw --> patches)
     // returns tensor with shape [n_embd, n_patches]
     ggml_tensor * build_inp() {
@@ -3009,6 +2700,208 @@ private:
         return cur;
     }
 
+    ggml_tensor * build_sam(ggml_tensor * inp_raw) {
+        const int n_embd = 768;
+        const int _depth = 12;
+        const int n_heads = 12;
+        const int d_heads = n_embd / n_heads;
+
+        ggml_tensor * inpL;
+
+        inpL = ggml_conv_2d_sk_p0(ctx0, model.patch_embed_proj_w, inp_raw);
+        inpL = ggml_add(ctx0, inpL, ggml_reshape_3d(ctx0, model.patch_embed_proj_b, 1, 1, n_embd));
+        inpL = ggml_cont(ctx0, ggml_permute(ctx0, inpL, 1, 2, 0, 3));
+
+        ggml_tensor * cur;
+        const auto tgt_size = inpL->ne[1];
+        const auto str_size = model.pos_embed->ne[1];
+
+        if (str_size != tgt_size) {
+            ggml_tensor * old_pos_embed = nullptr;
+            old_pos_embed = ggml_cont(ctx0, ggml_permute(ctx0, model.pos_embed, 2, 0, 1, 3));
+            ggml_tensor * new_pos_embed = ggml_interpolate(
+                ctx0,
+                old_pos_embed,
+                tgt_size,
+                tgt_size,
+                n_embd,
+                1,
+                ggml_scale_mode::GGML_SCALE_MODE_BICUBIC
+            );
+            new_pos_embed = ggml_cont(ctx0, ggml_permute(ctx0, new_pos_embed, 1, 2, 0, 3));
+            cur = ggml_add(ctx0, inpL, new_pos_embed);
+        } else {
+            cur = ggml_add(ctx0, inpL, model.pos_embed);
+        }
+
+        // loop over layers
+        for (int il = 0; il < _depth; il++) {
+            auto & layer = model.sam_layers[il];
+            ggml_tensor * shortcut = cur;
+
+            // layernorm1
+            cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, eps, il);
+
+            const int64_t w0 = cur->ne[1];
+            const int64_t h0 = cur->ne[2];
+
+            if (hparams.is_global_attn(il) == false) {
+                // local attention layer - apply window partition
+                cur = window_partition(ctx0, cur, 14); // TODO: make this configurable
+            }
+
+            const int64_t W = cur->ne[1];
+            const int64_t H = cur->ne[2];
+
+            // self-attention
+            {
+                const int B = cur->ne[3];
+
+                cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
+                cur = ggml_add(ctx0, cur, layer.qkv_b);
+                cur = ggml_cont(ctx0, cur); // Ensure tensor is contiguous before reshape
+                cur = ggml_reshape_4d(ctx0, cur, n_embd, 3, W*H, B);
+
+                ggml_tensor * Q;
+                ggml_tensor * K;
+                ggml_tensor * V;
+
+                Q = ggml_view_3d (ctx0, cur, n_embd, W*H, B, cur->nb[2], cur->nb[3], 0*cur->nb[1]);
+                Q = ggml_reshape_4d(ctx0, ggml_cont(ctx0, Q), d_heads, n_heads, W*H, B);
+                Q = ggml_cont (ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); // [B, n_heads, H*W, d_heads]
+
+                K = ggml_view_3d (ctx0, cur, n_embd, W*H, B, cur->nb[2], cur->nb[3], 1*cur->nb[1]);
+                K = ggml_reshape_4d(ctx0, ggml_cont(ctx0, K), d_heads, n_heads, W*H, B);
+                K = ggml_cont (ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); // [B, n_heads, H*W, d_heads]
+
+                V = ggml_view_3d (ctx0, cur, n_embd, W*H, B, cur->nb[2], cur->nb[3], 2*cur->nb[1]);
+                V = ggml_reshape_4d(ctx0, ggml_cont(ctx0, V), d_heads, n_heads, W*H, B);
+                V = ggml_cont (ctx0, ggml_permute(ctx0, V, 0, 2, 1, 3)); // [B, n_heads, H*W, d_heads]
+
+                ggml_tensor * mask;
+                ggml_tensor * rw;
+                ggml_tensor * rh;
+                ggml_tensor * qr;
+
+                rw = get_rel_pos(ctx0, layer.rel_pos_w, W, W); // [W, W, C]
+                rh = get_rel_pos(ctx0, layer.rel_pos_h, H, H); // [H, H, C]
+                qr = ggml_reshape_4d(ctx0, Q, d_heads, W, H, B*n_heads);
+
+                const int WH_pad = GGML_PAD(W*H, GGML_KQ_MASK_PAD) - W*H;
+
+                rw = ggml_mul_mat (ctx0, rw, ggml_cont(ctx0, ggml_permute(ctx0, qr, 0, 2, 1, 3))); // [B*n_heads, W, H, W]
+                rw = ggml_cont (ctx0, ggml_permute(ctx0, rw, 0, 2, 1, 3)); // [B*n_heads, H, W, W]
+                rw = ggml_reshape_4d(ctx0, rw, W, 1, W*H, n_heads*B);
+                rw = ggml_repeat_4d (ctx0, rw, W, H, W*H, n_heads*B);
+                rh = ggml_mul_mat (ctx0, rh, qr); // [B*n_heads, H, W, H]
+                rh = ggml_reshape_4d(ctx0, rh, 1, H, W*H, n_heads*B);
+                mask = ggml_add (ctx0, rw, rh); // [B*n_heads, H*W, H, W]
+                mask = ggml_reshape_4d(ctx0, mask, W*H, W*H, n_heads, B);
+                mask = ggml_pad (ctx0, mask, 0, WH_pad, 0, 0);
+                mask = ggml_cast (ctx0, mask, GGML_TYPE_F16);
+
+                float scale = 1.0f / sqrtf((float)d_heads);
+                cur = ggml_flash_attn_ext(ctx0, Q, K, V, mask, scale, 0.0f, 0.0f); // [B, H*W, n_heads, d_heads]
+
+                cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur), n_embd, W, H, B);
+                cur = ggml_mul_mat(ctx0, layer.o_w, cur);
+                cur = ggml_add_inplace(ctx0, cur, layer.o_b);
+            }
+
+            if (hparams.is_global_attn(il) == false) {
+                // local attention layer - reverse window partition
+                cur = window_unpartition(ctx0, cur, w0, h0, 14); // TODO: make window size configurable
+            }
+
+            // re-add the layer input, i.e. the residual
+            cur = ggml_add(ctx0, cur, shortcut);
+
+            ggml_tensor * inpFF = cur;
+
+            // layernorm2
+            cur = build_norm(inpFF, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, eps, il);
+
+            // ffn
+            cur = build_ffn(cur, layer.ff_up_w, layer.ff_up_b, nullptr, nullptr, layer.ff_down_w,
+                            layer.ff_down_b, hparams.ffn_op, il);
+
+            // residual 2
+            cur = ggml_add(ctx0, cur, inpFF);
+            cb(cur, "sam_layer_out", il);
+        }
+
+        cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3));
+
+        cur = ggml_conv_2d(ctx0, model.neck_0_w, cur, 1, 1, 0, 0, 1, 1);
+        cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3));
+        cur = build_norm(cur, model.neck_1_w, model.neck_1_b, NORM_TYPE_NORMAL, hparams.eps, -1);
+        cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3));
+
+        cur = ggml_conv_2d(ctx0, model.neck_2_w, cur, 1, 1, 1, 1, 1, 1);
+        cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3));
+        cur = build_norm(cur, model.neck_3_w, model.neck_3_b, NORM_TYPE_NORMAL, hparams.eps, -1);
+        cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3));
+
+        cur = ggml_conv_2d(ctx0, model.net_2, cur, 2, 2, 1, 1, 1, 1);
+        cur = ggml_conv_2d(ctx0, model.net_3, cur, 2, 2, 1, 1, 1, 1);
+        cb(cur, "sam_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+        return cur;
+    }
+
+    ggml_tensor * build_dsocr_clip(ggml_tensor * patch_embeds) {
+        ggml_tensor * inp;
+
+        inp = ggml_cpy(ctx0, patch_embeds, ggml_dup_tensor(ctx0, patch_embeds));
+        inp = ggml_reshape_2d(ctx0, inp, inp->ne[0]*inp->ne[1], inp->ne[2]);
+        inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3));
+
+        ggml_tensor * new_pos_embd = ggml_cpy(ctx0, model.position_embeddings, ggml_dup_tensor(ctx0, model.position_embeddings));
+
+        int n_pos = new_pos_embd->ne[1]; // includes the extra [CLS] position
+        const auto tgt_size = static_cast<int>(std::sqrt(inp->ne[1]));
+        const auto src_size = static_cast<int>(std::sqrt(n_pos - 1));
+
+        if (tgt_size != src_size) {
+            ggml_tensor * old_pos_embd;
+            ggml_tensor * cls_tok;
+
+            old_pos_embd = ggml_view_2d(
+                ctx0, new_pos_embd,
+                new_pos_embd->ne[0], src_size * src_size,
+                ggml_row_size(new_pos_embd->type, new_pos_embd->ne[0]), 0
+            );
+            cls_tok = ggml_view_2d(
+                ctx0, new_pos_embd,
+                new_pos_embd->ne[0], 1,
+                ggml_row_size(new_pos_embd->type, new_pos_embd->ne[0]), src_size * src_size
+            );
+            new_pos_embd = ggml_interpolate(ctx0,
+                old_pos_embd,
+                tgt_size,
+                tgt_size,
+                new_pos_embd->ne[0], 1, GGML_SCALE_MODE_BICUBIC
+            );
+            new_pos_embd = ggml_reshape_3d(ctx0, new_pos_embd, n_embd, tgt_size * tgt_size, 1);
+            new_pos_embd = ggml_concat(ctx0, new_pos_embd, cls_tok, 1);
+            n_pos = tgt_size * tgt_size + 1;
+        }
+
+        // add CLS token
+        inp = ggml_concat(ctx0, model.class_embedding, inp, 1);
+
+        // for selecting learned pos embd, used by ViT
+        ggml_tensor * positions = ggml_cast(ctx0, ggml_arange(ctx0, 0, n_pos, 1), GGML_TYPE_I32);
+        ggml_tensor * learned_pos_embd = ggml_get_rows(ctx0, new_pos_embd, positions);
+
+        ggml_tensor * cur = build_vit(inp, n_pos, NORM_TYPE_NORMAL, ffn_op_type::FFN_GELU_QUICK,
+                                      learned_pos_embd, nullptr); // shape [1024, 16, 16]
+
+        ggml_build_forward_expand(gf, cur);
+
+        return cur;
+    }
 };
 
 static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
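
Reviewer note, not part of the patch: the `rw`/`rh` mask construction in `build_sam` encodes SAM's decomposed relative-position attention. For a query at spatial position (qh, qw) and a key at (kh, kw), the additive bias is `q·Rh[qh][kh] + q·Rw[qw][kw]`; the graph materializes this as a [W*H, W*H] tensor per head, pads it up to `GGML_KQ_MASK_PAD` granularity, casts to F16, and passes it to `ggml_flash_attn_ext` as the mask. Below is a minimal ggml-free sketch of that arithmetic; `rel_pos_bias` is a hypothetical helper, and `Rh`/`Rw` are assumed to already hold the per-index-pair embeddings that `get_rel_pos` gathers.

```cpp
#include <cstdio>
#include <vector>

// q  : [H*W, C]   query vectors of one attention head
// Rh : [H, H, C]  relative embeddings gathered along height, Rh[qh][kh]
// Rw : [W, W, C]  relative embeddings gathered along width,  Rw[qw][kw]
// out: [H*W, H*W] additive attention bias
static std::vector<float> rel_pos_bias(const std::vector<float> & q,
                                       const std::vector<float> & Rh,
                                       const std::vector<float> & Rw,
                                       int H, int W, int C) {
    std::vector<float> bias((size_t) H*W*H*W);
    for (int qh = 0; qh < H; qh++) {
        for (int qw = 0; qw < W; qw++) {
            const float * qv = q.data() + ((size_t) qh*W + qw)*C;
            for (int kh = 0; kh < H; kh++) {
                for (int kw = 0; kw < W; kw++) {
                    float dh = 0.0f; // q . Rh[qh][kh] -> the rh term in the graph
                    float dw = 0.0f; // q . Rw[qw][kw] -> the rw term in the graph
                    for (int c = 0; c < C; c++) {
                        dh += qv[c]*Rh[((size_t) qh*H + kh)*C + c];
                        dw += qv[c]*Rw[((size_t) qw*W + kw)*C + c];
                    }
                    bias[((size_t) qh*W + qw)*H*W + (size_t) kh*W + kw] = dh + dw;
                }
            }
        }
    }
    return bias;
}

int main() {
    // tiny smoke test: H = W = 2, C = 1, constant inputs
    const int H = 2, W = 2, C = 1;
    std::vector<float> q (H*W*C, 1.00f);
    std::vector<float> Rh(H*H*C, 0.50f);
    std::vector<float> Rw(W*W*C, 0.25f);
    const auto bias = rel_pos_bias(q, Rh, Rw, H, W, C);
    printf("bias(0,3) = %.2f\n", bias[3]); // 1*0.50 + 1*0.25 = 0.75
    return 0;
}
```

The `ggml_reshape_4d`/`ggml_repeat_4d` sequence in the patch is this same broadcast expressed with tensor ops: the rh term varies only with (qh, qw, kh) and the rw term only with (qh, qw, kw), so each is computed once and repeated across the missing key axis before the add.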