diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 2a8bde0fd2..51005ab31a 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -1385,9 +1385,11 @@ ggml_tensor * llama_kv_cache::build_rope_shift( // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly. // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation. + /* const float yarn_attn_factor = (model.arch == LLM_ARCH_DEEPSEEK2 || model.arch == LLM_ARCH_DEEPSEEK2OCR) ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor; + */ ggml_tensor * tmp; diff --git a/tools/mtmd/CMakeLists.txt b/tools/mtmd/CMakeLists.txt index 3ee42036fd..6b80fe45f7 100644 --- a/tools/mtmd/CMakeLists.txt +++ b/tools/mtmd/CMakeLists.txt @@ -25,6 +25,7 @@ add_library(mtmd models/qwen3vl.cpp models/siglip.cpp models/whisper-enc.cpp + models/deepseekocr.cpp ) set_target_properties(mtmd PROPERTIES diff --git a/tools/mtmd/clip-model.h b/tools/mtmd/clip-model.h index 51bcce1ebb..46577a2a4a 100644 --- a/tools/mtmd/clip-model.h +++ b/tools/mtmd/clip-model.h @@ -60,6 +60,11 @@ struct clip_hparams { std::unordered_set vision_feature_layer; int32_t attn_window_size = 0; int32_t n_wa_pattern = 0; + + // deepseek-ocr (sam) + int32_t sam_n_layer = 0; + int32_t sam_n_head = 0; + int32_t sam_n_embd = 0; // audio int32_t n_mel_bins = 0; // whisper preprocessor @@ -89,6 +94,21 @@ struct clip_hparams { warmup_image_size = n_tok_per_side * patch_size * cur_merge; // TODO: support warmup size for custom token numbers } + // sam vit deepseek-ocr + std::vector global_attn_indices() const { + return { 2, 5, 8, 11 }; + } + bool is_global_attn(int32_t layer) const { + const auto indices = global_attn_indices(); + + for (const auto & idx : indices) { + if (layer == idx) { + return true; + } + } + + return false; + } }; struct clip_layer { @@ -134,6 +154,10 @@ struct clip_layer { ggml_tensor * deepstack_fc1_b = nullptr; ggml_tensor * deepstack_fc2_w = nullptr; ggml_tensor * deepstack_fc2_b = nullptr; + + // sam rel_pos + ggml_tensor * rel_pos_w = nullptr; + ggml_tensor * rel_pos_h = nullptr; bool has_deepstack() const { return deepstack_fc1_w != nullptr; @@ -162,7 +186,8 @@ struct clip_model { ggml_tensor * post_ln_w; ggml_tensor * post_ln_b; - ggml_tensor * projection; // TODO: rename it to fc (fully connected layer) + ggml_tensor * fc_w; + ggml_tensor * fc_b; ggml_tensor * mm_fc_w; ggml_tensor * mm_fc_b; @@ -175,6 +200,8 @@ struct clip_model { ggml_tensor * mm_2_b = nullptr; ggml_tensor * image_newline = nullptr; + ggml_tensor * view_seperator = nullptr; + // Yi type models with mlp+normalization projection ggml_tensor * mm_1_w = nullptr; // Yi type models have 0, 1, 3, 4 @@ -266,6 +293,24 @@ struct clip_model { ggml_tensor * mm_4h_to_h_w = nullptr; ggml_tensor * mm_boi = nullptr; ggml_tensor * mm_eoi = nullptr; + + // deepseek ocr sam + ggml_tensor * patch_embed_proj_w = nullptr; + ggml_tensor * patch_embed_proj_b = nullptr; + ggml_tensor * pos_embed = nullptr; + + ggml_tensor * neck_0_w; + ggml_tensor * neck_1_w; + ggml_tensor * neck_1_b; + ggml_tensor * neck_2_w; + ggml_tensor * neck_3_w; + ggml_tensor * neck_3_b; + ggml_tensor * net_2; + ggml_tensor * net_3; + + int32_t n_sam_layers = 12; // used by deepseek-ocr sam encoder + + std::vector sam_layers; bool audio_has_avgpool() const { return proj_type == PROJECTOR_TYPE_QWEN2A diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 569fceb145..507c2d3407 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -453,104 +453,7 @@ ggml_tensor * clip_graph::build_vit( return inpL; } -static ggml_tensor * get_rel_pos( - ggml_context * ctx, - ggml_tensor * rel_pos, // [L, C] - ggml_tensor * indices, // [q_size, k_size] - int q_size, - int k_size - ) { - const int64_t C = rel_pos->ne[0]; // channels - const int64_t L = rel_pos->ne[1]; // length - GGML_ASSERT(indices != nullptr); - GGML_ASSERT(indices->type == GGML_TYPE_I32); - GGML_ASSERT(indices->ne[0] == k_size); - GGML_ASSERT(indices->ne[1] == q_size); - - const auto max_rel_dist = 2*std::max(q_size, k_size) - 1; - ggml_tensor * cur = rel_pos; - - if (max_rel_dist != L) { - // Linear interpolation - int64_t ne0 = cur->ne[0]; - int64_t ne1 = cur->ne[1]; - int64_t ne2 = cur->ne[2]; - int64_t ne3 = cur->ne[3]; - - cur = ggml_reshape_3d( - ctx, - ggml_cont(ctx, ggml_permute(ctx, cur, 1, 0, 2, 3)), - ne1, 1, ne0*ne2*ne3 - ); - cur = ggml_reshape_4d( - ctx, - ggml_interpolate( - ctx, - cur, - max_rel_dist, 1, ne0*ne2*ne3, 1, - ggml_scale_mode::GGML_SCALE_MODE_BILINEAR - ), - max_rel_dist, ne0, ne2, ne3 - ); - cur = ggml_cont(ctx, ggml_permute(ctx, cur, 1, 0, 2, 3)); - } - - // Flatten indices to 1D for ggml_get_rows - int qk = q_size * k_size; - - cur = ggml_reshape_3d( - ctx, - ggml_get_rows(ctx, cur, ggml_reshape_1d(ctx, indices, qk)), - C, k_size, q_size - ); - - return cur; // [C, k_size, q_size] - } - - // Implementation based on approach suggested by Acly - // See: https://github.com/ggml-org/llama.cpp/pull/17383#issuecomment-3554227091 - static ggml_tensor* window_partition(ggml_context* ctx, ggml_tensor* x, int window) { - auto [c, w, h, b] = x->ne; - // same as - // x = ggml_win_part(m, x, window); - // x = ggml_reshape_3d(m, x, c, window * window, x->ne[3]); - - int64_t px = (window - w % window) % window; - int64_t py = (window - h % window) % window; - int64_t npw = (w + px) / window; - int64_t nph = (h + py) / window; - - if (px > 0 || py > 0) { - x = ggml_pad(ctx, x, 0, int(px), int(py), 0); - } - x = ggml_reshape_4d(ctx, x, c * window, npw, window, nph * b); - x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); - x = ggml_reshape_4d(ctx, x, c, window, window, npw * nph * b); - return x; - } - - // Implementation based on approach suggested by Acly - // See: https://github.com/ggml-org/llama.cpp/pull/17383#issuecomment-3554227091 - static ggml_tensor* window_unpartition(ggml_context* m, ggml_tensor* x, int w, int h, int window) { - int64_t c = x->ne[0]; - // same as - // x = ggml_reshape_4d(m, x, c, window, window, x->ne[2]); - // x = ggml_win_unpart(m, x, w, h, window); - - int64_t px = (window - w % window) % window; - int64_t py = (window - h % window) % window; - int64_t npw = (w + px) / window; - int64_t nph = (h + py) / window; - - int64_t b = x->ne[3] / (npw * nph); - x = ggml_reshape_4d(m, x, c * window, window, npw, nph * b); - x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3)); - x = ggml_reshape_4d(m, x, c, w + px, h + py, b); - x = ggml_view_4d(m, x, x->ne[0], w, h, x->ne[3], x->nb[1], x->nb[2], x->nb[3], 0); - x = ggml_cont(m, x); - return x; - } // build the input after conv2d (inp_raw --> patches) // returns tensor with shape [n_embd, n_patches] @@ -850,219 +753,7 @@ ggml_tensor * clip_graph::build_patch_merge_permute(ggml_tensor * cur, int scale return cur; } - ggml_tensor * build_sam(ggml_tensor * inp_raw) { - const int n_embd = hparams.sam_n_embd; - const int n_layer = hparams.sam_n_layer; - const int n_heads = hparams.sam_n_head; - const int d_heads = n_embd / n_heads; - const int window = hparams.attn_window_size; - - ggml_tensor * inpL; - - inpL = ggml_conv_2d_sk_p0(ctx0, model.patch_embed_proj_w, inp_raw); - inpL = ggml_add(ctx0, inpL, ggml_reshape_3d(ctx0, model.patch_embed_proj_b, 1, 1, n_embd)); - inpL = ggml_cont(ctx0, ggml_permute(ctx0, inpL, 1, 2, 0, 3)); - - ggml_tensor * rel_pos_indices_local; - ggml_tensor * rel_pos_indices_global; - - rel_pos_indices_local = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, window, window); - rel_pos_indices_global = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, inpL->ne[1], inpL->ne[2]); - ggml_set_name(rel_pos_indices_local, "rel_pos_indices_local"); - ggml_set_name(rel_pos_indices_global, "rel_pos_indices_global"); - ggml_set_input(rel_pos_indices_local); - ggml_set_input(rel_pos_indices_global); - - ggml_tensor * cur; - const auto tgt_size = inpL->ne[1]; - const auto str_size = model.pos_embed->ne[1]; - - if (str_size != tgt_size) { - ggml_tensor * old_pos_embed = nullptr; - old_pos_embed = ggml_cont(ctx0, ggml_permute(ctx0, model.pos_embed, 2, 0, 1, 3)); - ggml_tensor * new_pos_embed = ggml_interpolate( - ctx0, - old_pos_embed, - tgt_size, - tgt_size, - n_embd, - 1, - ggml_scale_mode::GGML_SCALE_MODE_BICUBIC - ); - new_pos_embed = ggml_cont(ctx0, ggml_permute(ctx0, new_pos_embed, 1, 2, 0, 3)); - cur = ggml_add(ctx0, inpL, new_pos_embed); - } else { - cur = ggml_add(ctx0, inpL, model.pos_embed); - } - - // loop over layers - for (int il = 0; il < n_layer; il++) { - auto & layer = model.sam_layers[il]; - ggml_tensor * shortcut = cur; - - // layernorm1 - cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, eps, il); - - const int64_t w0 = cur->ne[1]; - const int64_t h0 = cur->ne[2]; - - ggml_tensor * indices; - - if (hparams.is_global_attn(il)) { - indices = rel_pos_indices_global; - } else { - // local attention layer - apply window partition - cur = window_partition(ctx0, cur, window); - indices = rel_pos_indices_local; - } - - const int64_t W = cur->ne[1]; - const int64_t H = cur->ne[2]; - // self-attention - { - const int B = cur->ne[3]; - - cur = ggml_mul_mat(ctx0, layer.qkv_w, cur); - cur = ggml_add(ctx0, cur, layer.qkv_b); - cur = ggml_cont(ctx0, cur); // Ensure tensor is contiguous before reshape - cur = ggml_reshape_4d(ctx0, cur, n_embd, 3, W*H, B); - - ggml_tensor * Q; - ggml_tensor * K; - ggml_tensor * V; - - Q = ggml_view_3d (ctx0, cur, n_embd, W*H, B, cur->nb[2], cur->nb[3], 0*cur->nb[1]); - Q = ggml_reshape_4d(ctx0, ggml_cont(ctx0, Q), d_heads, n_heads, W*H, B); - - K = ggml_view_3d (ctx0, cur, n_embd, W*H, B, cur->nb[2], cur->nb[3], 1*cur->nb[1]); - K = ggml_reshape_4d(ctx0, ggml_cont(ctx0, K), d_heads, n_heads, W*H, B); - - V = ggml_view_3d (ctx0, cur, n_embd, W*H, B, cur->nb[2], cur->nb[3], 2*cur->nb[1]); - V = ggml_reshape_4d(ctx0, ggml_cont(ctx0, V), d_heads, n_heads, W*H, B); - - ggml_tensor * mask; - ggml_tensor * rw; - ggml_tensor * rh; - ggml_tensor * qr; - - rw = get_rel_pos(ctx0, layer.rel_pos_w, indices, W, W); // [W, W, C] - rh = get_rel_pos(ctx0, layer.rel_pos_h, indices, H, H); // [H, H, C] - qr = ggml_permute(ctx0, Q, 0, 2, 1, 3); - qr = ggml_reshape_4d(ctx0, ggml_cont(ctx0, qr), d_heads, W, H, B * n_heads); - - - rw = ggml_mul_mat (ctx0, rw, ggml_cont(ctx0, ggml_permute(ctx0, qr, 0, 2, 1, 3))); // [B*n_heads, W, H, W] - rw = ggml_cont (ctx0, ggml_permute(ctx0, rw, 0, 2, 1, 3)); // [B*n_heads, H, W, W] - rw = ggml_reshape_4d(ctx0, rw, W, 1, W*H, n_heads*B); - rw = ggml_repeat_4d (ctx0, rw, W, H, W*H, n_heads*B); - rh = ggml_mul_mat (ctx0, rh, qr); // [B*n_heads, H, W, H] - rh = ggml_reshape_4d(ctx0, rh, 1, H, W*H, n_heads*B); - mask = ggml_add (ctx0, rw, rh); // [B*n_heads, H*W, H, W] - mask = ggml_reshape_4d(ctx0, mask, W*H, W*H, n_heads, B); - mask = ggml_cast (ctx0, mask, GGML_TYPE_F16); - - float scale = 1.0f / sqrtf((float)d_heads); - - cur = build_attn(layer.o_w, layer.o_b, Q, K, V, mask, scale, - il); // [B, H*W, n_embd] - cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur), n_embd, W, H, B); - } - - if (hparams.is_global_attn(il) == false) { - // local attention layer - reverse window partition - cur = window_unpartition(ctx0, cur, w0, h0, window); - } - - // re-add the layer input, e.g., residual - cur = ggml_add(ctx0, cur, shortcut); - - ggml_tensor * inpFF = cur; - - // layernorm2 - cur = build_norm(inpFF, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, eps, il); - - // ffn - cur = build_ffn(cur, layer.ff_up_w, layer.ff_up_b, nullptr, nullptr, layer.ff_down_w, - layer.ff_down_b, hparams.ffn_op, il); - - // residual 2 - cur = ggml_add(ctx0, cur, inpFF); - cb(cur, "sam_layer_out", il); - } - - cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3)); - - cur = ggml_conv_2d(ctx0, model.neck_0_w, cur, 1, 1, 0, 0, 1, 1); - cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3)); - cur = build_norm(cur, model.neck_1_w, model.neck_1_b, NORM_TYPE_NORMAL, hparams.eps, -1); - cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3)); - - cur = ggml_conv_2d(ctx0, model.neck_2_w, cur, 1, 1, 1, 1, 1, 1); - cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3)); - cur = build_norm(cur, model.neck_3_w, model.neck_3_b, NORM_TYPE_NORMAL, hparams.eps, -1); - cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3)); - - cur = ggml_conv_2d(ctx0, model.net_2, cur, 2, 2, 1, 1, 1, 1); - cur = ggml_conv_2d(ctx0, model.net_3, cur, 2, 2, 1, 1, 1, 1); - cb(cur, "sam_output", -1); - - ggml_build_forward_expand(gf, cur); - return cur; - } - - ggml_tensor * build_dsocr_clip(ggml_tensor * patch_embeds) { - ggml_tensor * inp; - - inp = ggml_cpy(ctx0, patch_embeds, ggml_dup_tensor(ctx0, patch_embeds)); - inp = ggml_reshape_2d(ctx0, inp, inp->ne[0]*inp->ne[1], inp->ne[2]); - inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3)); - - ggml_tensor * new_pos_embd = ggml_cpy(ctx0, model.position_embeddings, ggml_dup_tensor(ctx0, model.position_embeddings)); - - int n_pos = new_pos_embd->ne[1]; // +1 for [CLS] - const auto tgt_size = static_cast(std::sqrt(inp->ne[1])); - const auto src_size = static_cast(std::sqrt(n_pos - 1)); - - if (tgt_size != src_size) { - ggml_tensor * old_pos_embd; - ggml_tensor * cls_tok; - - old_pos_embd = ggml_view_2d( - ctx0, new_pos_embd, - new_pos_embd->ne[0], src_size * src_size, - ggml_row_size(new_pos_embd->type, new_pos_embd->ne[0]), 0 - ); - cls_tok = ggml_view_2d( - ctx0, new_pos_embd, - new_pos_embd->ne[0], 1, - ggml_row_size(new_pos_embd->type, new_pos_embd->ne[0]), src_size * src_size - ); - new_pos_embd = ggml_interpolate(ctx0, - old_pos_embd, - tgt_size, - tgt_size, - new_pos_embd->ne[0], 1, GGML_SCALE_MODE_BICUBIC - ); - new_pos_embd = ggml_reshape_3d(ctx0, new_pos_embd, n_embd, tgt_size * tgt_size, 1); - new_pos_embd = ggml_concat(ctx0, new_pos_embd, cls_tok, 1); - n_pos = tgt_size * tgt_size + 1; - } - - // add CLS token - inp = ggml_concat(ctx0, model.class_embedding, inp, 1); - - // for selecting learned pos embd, used by ViT - ggml_tensor * positions = ggml_cast(ctx0, ggml_arange(ctx0, 0, n_pos, 1), GGML_TYPE_I32); - ggml_tensor * learned_pos_embd = ggml_get_rows(ctx0, new_pos_embd, positions); - - ggml_tensor * cur = build_vit(inp, n_pos, NORM_TYPE_NORMAL, ffn_op_type::FFN_GELU_QUICK, - learned_pos_embd, nullptr); - - ggml_build_forward_expand(gf, cur); - - return cur; - } -}; + static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch & imgs) { GGML_ASSERT(imgs.entries.size() == 1 && "n_batch > 1 is not supported"); @@ -1127,9 +818,9 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 builder = std::make_unique(ctx, img); } break; case PROJECTOR_TYPE_DEEPSEEKOCR: - { - res = graph.build_deepseek_ocr(); - } break; + { + builder = std::make_unique(ctx, img); + } break; default: GGML_ABORT("missing cgraph builder"); } diff --git a/tools/mtmd/models/deepseekocr.cpp b/tools/mtmd/models/deepseekocr.cpp new file mode 100644 index 0000000000..1691b97b67 --- /dev/null +++ b/tools/mtmd/models/deepseekocr.cpp @@ -0,0 +1,364 @@ +#include "models.h" + +ggml_tensor* clip_graph_deepseekocr::build_sam(ggml_tensor* inp_raw) +{ + const int n_embd = hparams.sam_n_embd; + const int n_layer = hparams.sam_n_layer; + const int n_heads = hparams.sam_n_head; + const int d_heads = n_embd / n_heads; + const int window = hparams.attn_window_size; + + ggml_tensor* inpL; + + inpL = ggml_conv_2d_sk_p0(ctx0, model.patch_embed_proj_w, inp_raw); + inpL = ggml_add(ctx0, inpL, ggml_reshape_3d(ctx0, model.patch_embed_proj_b, 1, 1, n_embd)); + inpL = ggml_cont(ctx0, ggml_permute(ctx0, inpL, 1, 2, 0, 3)); + + ggml_tensor* rel_pos_indices_local; + ggml_tensor* rel_pos_indices_global; + + rel_pos_indices_local = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, window, window); + rel_pos_indices_global = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, inpL->ne[1], inpL->ne[2]); + ggml_set_name(rel_pos_indices_local, "rel_pos_indices_local"); + ggml_set_name(rel_pos_indices_global, "rel_pos_indices_global"); + ggml_set_input(rel_pos_indices_local); + ggml_set_input(rel_pos_indices_global); + + ggml_tensor* cur; + const auto tgt_size = inpL->ne[1]; + const auto str_size = model.pos_embed->ne[1]; + + if (str_size != tgt_size) + { + ggml_tensor* old_pos_embed = nullptr; + old_pos_embed = ggml_cont(ctx0, ggml_permute(ctx0, model.pos_embed, 2, 0, 1, 3)); + ggml_tensor* new_pos_embed = ggml_interpolate( + ctx0, + old_pos_embed, + tgt_size, + tgt_size, + n_embd, + 1, + ggml_scale_mode::GGML_SCALE_MODE_BICUBIC + ); + new_pos_embed = ggml_cont(ctx0, ggml_permute(ctx0, new_pos_embed, 1, 2, 0, 3)); + cur = ggml_add(ctx0, inpL, new_pos_embed); + } + else + { + cur = ggml_add(ctx0, inpL, model.pos_embed); + } + + // loop over layers + for (int il = 0; il < n_layer; il++) + { + auto& layer = model.sam_layers[il]; + ggml_tensor* shortcut = cur; + + // layernorm1 + cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, eps, il); + + const int64_t w0 = cur->ne[1]; + const int64_t h0 = cur->ne[2]; + + ggml_tensor* indices; + + if (hparams.is_global_attn(il)) + { + indices = rel_pos_indices_global; + } + else + { + // local attention layer - apply window partition + cur = window_partition(cur, window); + indices = rel_pos_indices_local; + } + + const int64_t W = cur->ne[1]; + const int64_t H = cur->ne[2]; + // self-attention + { + const int B = cur->ne[3]; + + cur = ggml_mul_mat(ctx0, layer.qkv_w, cur); + cur = ggml_add(ctx0, cur, layer.qkv_b); + cur = ggml_cont(ctx0, cur); // Ensure tensor is contiguous before reshape + cur = ggml_reshape_4d(ctx0, cur, n_embd, 3, W * H, B); + + ggml_tensor* Q; + ggml_tensor* K; + ggml_tensor* V; + + Q = ggml_view_3d(ctx0, cur, n_embd, W * H, B, cur->nb[2], cur->nb[3], 0 * cur->nb[1]); + Q = ggml_reshape_4d(ctx0, ggml_cont(ctx0, Q), d_heads, n_heads, W * H, B); + + K = ggml_view_3d(ctx0, cur, n_embd, W * H, B, cur->nb[2], cur->nb[3], 1 * cur->nb[1]); + K = ggml_reshape_4d(ctx0, ggml_cont(ctx0, K), d_heads, n_heads, W * H, B); + + V = ggml_view_3d(ctx0, cur, n_embd, W * H, B, cur->nb[2], cur->nb[3], 2 * cur->nb[1]); + V = ggml_reshape_4d(ctx0, ggml_cont(ctx0, V), d_heads, n_heads, W * H, B); + + ggml_tensor* mask; + ggml_tensor* rw; + ggml_tensor* rh; + ggml_tensor* qr; + + rw = get_rel_pos(layer.rel_pos_w, indices, W, W); // [W, W, C] + rh = get_rel_pos(layer.rel_pos_h, indices, H, H); // [H, H, C] + qr = ggml_permute(ctx0, Q, 0, 2, 1, 3); + qr = ggml_reshape_4d(ctx0, ggml_cont(ctx0, qr), d_heads, W, H, B * n_heads); + + + rw = ggml_mul_mat(ctx0, rw, ggml_cont(ctx0, ggml_permute(ctx0, qr, 0, 2, 1, 3))); // [B*n_heads, W, H, W] + rw = ggml_cont(ctx0, ggml_permute(ctx0, rw, 0, 2, 1, 3)); // [B*n_heads, H, W, W] + rw = ggml_reshape_4d(ctx0, rw, W, 1, W * H, n_heads * B); + rw = ggml_repeat_4d(ctx0, rw, W, H, W * H, n_heads * B); + rh = ggml_mul_mat(ctx0, rh, qr); // [B*n_heads, H, W, H] + rh = ggml_reshape_4d(ctx0, rh, 1, H, W * H, n_heads * B); + mask = ggml_add(ctx0, rw, rh); // [B*n_heads, H*W, H, W] + mask = ggml_reshape_4d(ctx0, mask, W * H, W * H, n_heads, B); + mask = ggml_cast(ctx0, mask, GGML_TYPE_F16); + + float scale = 1.0f / sqrtf((float)d_heads); + + cur = build_attn(layer.o_w, layer.o_b, Q, K, V, mask, scale, + il); // [B, H*W, n_embd] + cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur), n_embd, W, H, B); + } + + if (hparams.is_global_attn(il) == false) + { + // local attention layer - reverse window partition + cur = window_unpartition(cur, w0, h0, window); + } + + // re-add the layer input, e.g., residual + cur = ggml_add(ctx0, cur, shortcut); + + ggml_tensor* inpFF = cur; + + // layernorm2 + cur = build_norm(inpFF, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, eps, il); + + // ffn + cur = build_ffn(cur, layer.ff_up_w, layer.ff_up_b, nullptr, nullptr, layer.ff_down_w, + layer.ff_down_b, hparams.ffn_op, il); + + // residual 2 + cur = ggml_add(ctx0, cur, inpFF); + cb(cur, "sam_layer_out", il); + } + + cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3)); + + cur = ggml_conv_2d(ctx0, model.neck_0_w, cur, 1, 1, 0, 0, 1, 1); + cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3)); + cur = build_norm(cur, model.neck_1_w, model.neck_1_b, NORM_TYPE_NORMAL, hparams.eps, -1); + cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3)); + + cur = ggml_conv_2d(ctx0, model.neck_2_w, cur, 1, 1, 1, 1, 1, 1); + cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3)); + cur = build_norm(cur, model.neck_3_w, model.neck_3_b, NORM_TYPE_NORMAL, hparams.eps, -1); + cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3)); + + cur = ggml_conv_2d(ctx0, model.net_2, cur, 2, 2, 1, 1, 1, 1); + cur = ggml_conv_2d(ctx0, model.net_3, cur, 2, 2, 1, 1, 1, 1); + cb(cur, "sam_output", -1); + + ggml_build_forward_expand(gf, cur); + return cur; +} + + +ggml_cgraph* clip_graph_deepseekocr::build() +{ + //patch embedding + ggml_tensor* inp_raw = build_inp_raw(); + ggml_tensor* sam_out = build_sam(inp_raw); + ggml_tensor* clip_out = build_dsocr_clip(sam_out); + + int clip_n_patches = sam_out->ne[0] * sam_out->ne[1]; + + sam_out = ggml_cont(ctx0, ggml_permute(ctx0, sam_out, 1, 2, 0, 3)); + sam_out = ggml_reshape_2d(ctx0, sam_out, sam_out->ne[0], clip_n_patches); + clip_out = ggml_view_2d(ctx0, clip_out, n_embd, clip_n_patches, clip_out->nb[1], clip_out->nb[1]); + + ggml_tensor* cur; + cur = ggml_concat(ctx0, clip_out, sam_out, 0); + cur = ggml_reshape_2d(ctx0, cur, 2 * n_embd, clip_n_patches); + cur = ggml_cont(ctx0, cur); + cur = ggml_mul_mat(ctx0, model.fc_w, cur); + cur = ggml_add(ctx0, cur, model.fc_b); + + const auto h = static_cast(std::sqrt(static_cast(cur->ne[1]))); + const auto w = h; + const auto n_dim = cur->ne[0]; + + ggml_tensor* imgnl; + ggml_tensor* vs; + + imgnl = ggml_repeat_4d(ctx0, model.image_newline, n_dim, 1, h, 1); + vs = ggml_reshape_2d(ctx0, model.view_seperator, n_dim, 1); // (n_dim, 1) + cur = ggml_reshape_3d(ctx0, cur, n_dim, w, h); + cur = ggml_reshape_2d(ctx0, ggml_concat(ctx0, cur, imgnl, 1), n_dim, (w + 1) * h); + cur = ggml_concat(ctx0, cur, vs, 1); // (n_dim, h*(w+1) + 1) + + cb(cur, "dsocr_output", -1); + + ggml_build_forward_expand(gf, cur); + return gf; +} + +ggml_tensor* clip_graph_deepseekocr::build_dsocr_clip(ggml_tensor* patch_embeds) +{ + ggml_tensor* inp; + + inp = ggml_cpy(ctx0, patch_embeds, ggml_dup_tensor(ctx0, patch_embeds)); + inp = ggml_reshape_2d(ctx0, inp, inp->ne[0] * inp->ne[1], inp->ne[2]); + inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3)); + + ggml_tensor* new_pos_embd = ggml_cpy(ctx0, model.position_embeddings, + ggml_dup_tensor(ctx0, model.position_embeddings)); + + int n_pos = new_pos_embd->ne[1]; // +1 for [CLS] + const auto tgt_size = static_cast(std::sqrt(inp->ne[1])); + const auto src_size = static_cast(std::sqrt(n_pos - 1)); + + if (tgt_size != src_size) + { + ggml_tensor* old_pos_embd; + ggml_tensor* cls_tok; + + old_pos_embd = ggml_view_2d( + ctx0, new_pos_embd, + new_pos_embd->ne[0], src_size * src_size, + ggml_row_size(new_pos_embd->type, new_pos_embd->ne[0]), 0 + ); + cls_tok = ggml_view_2d( + ctx0, new_pos_embd, + new_pos_embd->ne[0], 1, + ggml_row_size(new_pos_embd->type, new_pos_embd->ne[0]), src_size * src_size + ); + new_pos_embd = ggml_interpolate(ctx0, + old_pos_embd, + tgt_size, + tgt_size, + new_pos_embd->ne[0], 1, GGML_SCALE_MODE_BICUBIC + ); + new_pos_embd = ggml_reshape_3d(ctx0, new_pos_embd, n_embd, tgt_size * tgt_size, 1); + new_pos_embd = ggml_concat(ctx0, new_pos_embd, cls_tok, 1); + n_pos = tgt_size * tgt_size + 1; + } + + // add CLS token + inp = ggml_concat(ctx0, model.class_embedding, inp, 1); + + // for selecting learned pos embd, used by ViT + ggml_tensor* positions = ggml_cast(ctx0, ggml_arange(ctx0, 0, n_pos, 1), GGML_TYPE_I32); + ggml_tensor* learned_pos_embd = ggml_get_rows(ctx0, new_pos_embd, positions); + + ggml_tensor* cur = build_vit(inp, n_pos, NORM_TYPE_NORMAL, ffn_op_type::FFN_GELU_QUICK, + learned_pos_embd, nullptr); + + ggml_build_forward_expand(gf, cur); + + return cur; +} + +ggml_tensor * clip_graph_deepseekocr::get_rel_pos( + ggml_tensor * rel_pos, // [L, C] + ggml_tensor * indices, // [q_size, k_size] + int q_size, + int k_size + ) { + const int64_t C = rel_pos->ne[0]; // channels + const int64_t L = rel_pos->ne[1]; // length + + GGML_ASSERT(indices != nullptr); + GGML_ASSERT(indices->type == GGML_TYPE_I32); + GGML_ASSERT(indices->ne[0] == k_size); + GGML_ASSERT(indices->ne[1] == q_size); + + const auto max_rel_dist = 2*std::max(q_size, k_size) - 1; + ggml_tensor * cur = rel_pos; + + if (max_rel_dist != L) { + // Linear interpolation + int64_t ne0 = cur->ne[0]; + int64_t ne1 = cur->ne[1]; + int64_t ne2 = cur->ne[2]; + int64_t ne3 = cur->ne[3]; + + cur = ggml_reshape_3d( + ctx0, + ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 0, 2, 3)), + ne1, 1, ne0*ne2*ne3 + ); + cur = ggml_reshape_4d( + ctx0, + ggml_interpolate( + ctx0, + cur, + max_rel_dist, 1, ne0*ne2*ne3, 1, + ggml_scale_mode::GGML_SCALE_MODE_BILINEAR + ), + max_rel_dist, ne0, ne2, ne3 + ); + cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 0, 2, 3)); + } + + // Flatten indices to 1D for ggml_get_rows + int qk = q_size * k_size; + + cur = ggml_reshape_3d( + ctx0, + ggml_get_rows(ctx0, cur, ggml_reshape_1d(ctx0, indices, qk)), + C, k_size, q_size + ); + + return cur; // [C, k_size, q_size] + } + + // Implementation based on approach suggested by Acly + // See: https://github.com/ggml-org/llama.cpp/pull/17383#issuecomment-3554227091 + ggml_tensor* clip_graph_deepseekocr::window_partition(ggml_tensor* x, int window) { + auto [c, w, h, b] = x->ne; + // same as + // x = ggml_win_part(m, x, window); + // x = ggml_reshape_3d(m, x, c, window * window, x->ne[3]); + + int64_t px = (window - w % window) % window; + int64_t py = (window - h % window) % window; + int64_t npw = (w + px) / window; + int64_t nph = (h + py) / window; + + if (px > 0 || py > 0) { + x = ggml_pad(ctx0, x, 0, int(px), int(py), 0); + } + x = ggml_reshape_4d(ctx0, x, c * window, npw, window, nph * b); + x = ggml_cont(ctx0, ggml_permute(ctx0, x, 0, 2, 1, 3)); + x = ggml_reshape_4d(ctx0, x, c, window, window, npw * nph * b); + return x; + } + + // Implementation based on approach suggested by Acly + // See: https://github.com/ggml-org/llama.cpp/pull/17383#issuecomment-3554227091 + ggml_tensor* clip_graph_deepseekocr::window_unpartition(ggml_tensor* x, int w, int h, int window) { + int64_t c = x->ne[0]; + // same as + // x = ggml_reshape_4d(m, x, c, window, window, x->ne[2]); + // x = ggml_win_unpart(m, x, w, h, window); + + int64_t px = (window - w % window) % window; + int64_t py = (window - h % window) % window; + int64_t npw = (w + px) / window; + int64_t nph = (h + py) / window; + + int64_t b = x->ne[3] / (npw * nph); + x = ggml_reshape_4d(ctx0, x, c * window, window, npw, nph * b); + x = ggml_cont(ctx0, ggml_permute(ctx0, x, 0, 2, 1, 3)); + x = ggml_reshape_4d(ctx0, x, c, w + px, h + py, b); + x = ggml_view_4d(ctx0, x, x->ne[0], w, h, x->ne[3], x->nb[1], x->nb[2], x->nb[3], 0); + x = ggml_cont(ctx0, x); + return x; + } diff --git a/tools/mtmd/models/models.h b/tools/mtmd/models/models.h index 4b35da259c..420ac4501d 100644 --- a/tools/mtmd/models/models.h +++ b/tools/mtmd/models/models.h @@ -56,3 +56,14 @@ struct clip_graph_whisper_enc : clip_graph { clip_graph_whisper_enc(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} ggml_cgraph * build() override; }; + +struct clip_graph_deepseekocr : clip_graph { + clip_graph_deepseekocr(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; + + ggml_tensor * build_sam(ggml_tensor * inp_raw); + ggml_tensor * build_dsocr_clip(ggml_tensor * patch_embeds); + ggml_tensor * get_rel_pos(ggml_tensor * rel_pos, ggml_tensor * indices, int q_size, int k_size); + ggml_tensor * window_partition(ggml_tensor * x, int window); + ggml_tensor * window_unpartition(ggml_tensor * x, int w, int h, int window); +}; diff --git a/tools/mtmd/models/siglip.cpp b/tools/mtmd/models/siglip.cpp index ef094cfd0e..191694235d 100644 --- a/tools/mtmd/models/siglip.cpp +++ b/tools/mtmd/models/siglip.cpp @@ -43,7 +43,7 @@ ggml_cgraph * clip_graph_siglip::build() { // https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578 const int scale_factor = model.hparams.n_merge; cur = build_patch_merge_permute(cur, scale_factor); - cur = ggml_mul_mat(ctx0, model.projection, cur); + cur = ggml_mul_mat(ctx0, model.fc_w, cur); } else if (proj_type == PROJECTOR_TYPE_LFM2) { // pixel unshuffle block