quick and (potential) dirty merge with https://github.com/ggml-org/llama.cpp/pull/17909
This commit is contained in:
parent
e0e69fd3fb
commit
f95a6fe9f3
|
|
@ -1385,9 +1385,11 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
|
||||||
|
|
||||||
// See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
|
// See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
|
||||||
// See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
|
// See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
|
||||||
|
/*
|
||||||
const float yarn_attn_factor = (model.arch == LLM_ARCH_DEEPSEEK2 || model.arch == LLM_ARCH_DEEPSEEK2OCR)
|
const float yarn_attn_factor = (model.arch == LLM_ARCH_DEEPSEEK2 || model.arch == LLM_ARCH_DEEPSEEK2OCR)
|
||||||
? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale))
|
? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale))
|
||||||
: cparams.yarn_attn_factor;
|
: cparams.yarn_attn_factor;
|
||||||
|
*/
|
||||||
|
|
||||||
ggml_tensor * tmp;
|
ggml_tensor * tmp;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,7 @@ add_library(mtmd
|
||||||
models/qwen3vl.cpp
|
models/qwen3vl.cpp
|
||||||
models/siglip.cpp
|
models/siglip.cpp
|
||||||
models/whisper-enc.cpp
|
models/whisper-enc.cpp
|
||||||
|
models/deepseekocr.cpp
|
||||||
)
|
)
|
||||||
|
|
||||||
set_target_properties(mtmd PROPERTIES
|
set_target_properties(mtmd PROPERTIES
|
||||||
|
|
|
||||||
|
|
@ -60,6 +60,11 @@ struct clip_hparams {
|
||||||
std::unordered_set<int32_t> vision_feature_layer;
|
std::unordered_set<int32_t> vision_feature_layer;
|
||||||
int32_t attn_window_size = 0;
|
int32_t attn_window_size = 0;
|
||||||
int32_t n_wa_pattern = 0;
|
int32_t n_wa_pattern = 0;
|
||||||
|
|
||||||
|
// deepseek-ocr (sam)
|
||||||
|
int32_t sam_n_layer = 0;
|
||||||
|
int32_t sam_n_head = 0;
|
||||||
|
int32_t sam_n_embd = 0;
|
||||||
|
|
||||||
// audio
|
// audio
|
||||||
int32_t n_mel_bins = 0; // whisper preprocessor
|
int32_t n_mel_bins = 0; // whisper preprocessor
|
||||||
|
|
@ -89,6 +94,21 @@ struct clip_hparams {
|
||||||
warmup_image_size = n_tok_per_side * patch_size * cur_merge;
|
warmup_image_size = n_tok_per_side * patch_size * cur_merge;
|
||||||
// TODO: support warmup size for custom token numbers
|
// TODO: support warmup size for custom token numbers
|
||||||
}
|
}
|
||||||
|
// sam vit deepseek-ocr
|
||||||
|
std::vector<int32_t> global_attn_indices() const {
|
||||||
|
return { 2, 5, 8, 11 };
|
||||||
|
}
|
||||||
|
bool is_global_attn(int32_t layer) const {
|
||||||
|
const auto indices = global_attn_indices();
|
||||||
|
|
||||||
|
for (const auto & idx : indices) {
|
||||||
|
if (layer == idx) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct clip_layer {
|
struct clip_layer {
|
||||||
|
|
@ -134,6 +154,10 @@ struct clip_layer {
|
||||||
ggml_tensor * deepstack_fc1_b = nullptr;
|
ggml_tensor * deepstack_fc1_b = nullptr;
|
||||||
ggml_tensor * deepstack_fc2_w = nullptr;
|
ggml_tensor * deepstack_fc2_w = nullptr;
|
||||||
ggml_tensor * deepstack_fc2_b = nullptr;
|
ggml_tensor * deepstack_fc2_b = nullptr;
|
||||||
|
|
||||||
|
// sam rel_pos
|
||||||
|
ggml_tensor * rel_pos_w = nullptr;
|
||||||
|
ggml_tensor * rel_pos_h = nullptr;
|
||||||
|
|
||||||
bool has_deepstack() const {
|
bool has_deepstack() const {
|
||||||
return deepstack_fc1_w != nullptr;
|
return deepstack_fc1_w != nullptr;
|
||||||
|
|
@ -162,7 +186,8 @@ struct clip_model {
|
||||||
ggml_tensor * post_ln_w;
|
ggml_tensor * post_ln_w;
|
||||||
ggml_tensor * post_ln_b;
|
ggml_tensor * post_ln_b;
|
||||||
|
|
||||||
ggml_tensor * projection; // TODO: rename it to fc (fully connected layer)
|
ggml_tensor * fc_w;
|
||||||
|
ggml_tensor * fc_b;
|
||||||
ggml_tensor * mm_fc_w;
|
ggml_tensor * mm_fc_w;
|
||||||
ggml_tensor * mm_fc_b;
|
ggml_tensor * mm_fc_b;
|
||||||
|
|
||||||
|
|
@ -175,6 +200,8 @@ struct clip_model {
|
||||||
ggml_tensor * mm_2_b = nullptr;
|
ggml_tensor * mm_2_b = nullptr;
|
||||||
|
|
||||||
ggml_tensor * image_newline = nullptr;
|
ggml_tensor * image_newline = nullptr;
|
||||||
|
ggml_tensor * view_seperator = nullptr;
|
||||||
|
|
||||||
|
|
||||||
// Yi type models with mlp+normalization projection
|
// Yi type models with mlp+normalization projection
|
||||||
ggml_tensor * mm_1_w = nullptr; // Yi type models have 0, 1, 3, 4
|
ggml_tensor * mm_1_w = nullptr; // Yi type models have 0, 1, 3, 4
|
||||||
|
|
@ -266,6 +293,24 @@ struct clip_model {
|
||||||
ggml_tensor * mm_4h_to_h_w = nullptr;
|
ggml_tensor * mm_4h_to_h_w = nullptr;
|
||||||
ggml_tensor * mm_boi = nullptr;
|
ggml_tensor * mm_boi = nullptr;
|
||||||
ggml_tensor * mm_eoi = nullptr;
|
ggml_tensor * mm_eoi = nullptr;
|
||||||
|
|
||||||
|
// deepseek ocr sam
|
||||||
|
ggml_tensor * patch_embed_proj_w = nullptr;
|
||||||
|
ggml_tensor * patch_embed_proj_b = nullptr;
|
||||||
|
ggml_tensor * pos_embed = nullptr;
|
||||||
|
|
||||||
|
ggml_tensor * neck_0_w;
|
||||||
|
ggml_tensor * neck_1_w;
|
||||||
|
ggml_tensor * neck_1_b;
|
||||||
|
ggml_tensor * neck_2_w;
|
||||||
|
ggml_tensor * neck_3_w;
|
||||||
|
ggml_tensor * neck_3_b;
|
||||||
|
ggml_tensor * net_2;
|
||||||
|
ggml_tensor * net_3;
|
||||||
|
|
||||||
|
int32_t n_sam_layers = 12; // used by deepseek-ocr sam encoder
|
||||||
|
|
||||||
|
std::vector<clip_layer> sam_layers;
|
||||||
|
|
||||||
bool audio_has_avgpool() const {
|
bool audio_has_avgpool() const {
|
||||||
return proj_type == PROJECTOR_TYPE_QWEN2A
|
return proj_type == PROJECTOR_TYPE_QWEN2A
|
||||||
|
|
|
||||||
|
|
@ -453,104 +453,7 @@ ggml_tensor * clip_graph::build_vit(
|
||||||
return inpL;
|
return inpL;
|
||||||
}
|
}
|
||||||
|
|
||||||
static ggml_tensor * get_rel_pos(
|
|
||||||
ggml_context * ctx,
|
|
||||||
ggml_tensor * rel_pos, // [L, C]
|
|
||||||
ggml_tensor * indices, // [q_size, k_size]
|
|
||||||
int q_size,
|
|
||||||
int k_size
|
|
||||||
) {
|
|
||||||
const int64_t C = rel_pos->ne[0]; // channels
|
|
||||||
const int64_t L = rel_pos->ne[1]; // length
|
|
||||||
|
|
||||||
GGML_ASSERT(indices != nullptr);
|
|
||||||
GGML_ASSERT(indices->type == GGML_TYPE_I32);
|
|
||||||
GGML_ASSERT(indices->ne[0] == k_size);
|
|
||||||
GGML_ASSERT(indices->ne[1] == q_size);
|
|
||||||
|
|
||||||
const auto max_rel_dist = 2*std::max(q_size, k_size) - 1;
|
|
||||||
ggml_tensor * cur = rel_pos;
|
|
||||||
|
|
||||||
if (max_rel_dist != L) {
|
|
||||||
// Linear interpolation
|
|
||||||
int64_t ne0 = cur->ne[0];
|
|
||||||
int64_t ne1 = cur->ne[1];
|
|
||||||
int64_t ne2 = cur->ne[2];
|
|
||||||
int64_t ne3 = cur->ne[3];
|
|
||||||
|
|
||||||
cur = ggml_reshape_3d(
|
|
||||||
ctx,
|
|
||||||
ggml_cont(ctx, ggml_permute(ctx, cur, 1, 0, 2, 3)),
|
|
||||||
ne1, 1, ne0*ne2*ne3
|
|
||||||
);
|
|
||||||
cur = ggml_reshape_4d(
|
|
||||||
ctx,
|
|
||||||
ggml_interpolate(
|
|
||||||
ctx,
|
|
||||||
cur,
|
|
||||||
max_rel_dist, 1, ne0*ne2*ne3, 1,
|
|
||||||
ggml_scale_mode::GGML_SCALE_MODE_BILINEAR
|
|
||||||
),
|
|
||||||
max_rel_dist, ne0, ne2, ne3
|
|
||||||
);
|
|
||||||
cur = ggml_cont(ctx, ggml_permute(ctx, cur, 1, 0, 2, 3));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Flatten indices to 1D for ggml_get_rows
|
|
||||||
int qk = q_size * k_size;
|
|
||||||
|
|
||||||
cur = ggml_reshape_3d(
|
|
||||||
ctx,
|
|
||||||
ggml_get_rows(ctx, cur, ggml_reshape_1d(ctx, indices, qk)),
|
|
||||||
C, k_size, q_size
|
|
||||||
);
|
|
||||||
|
|
||||||
return cur; // [C, k_size, q_size]
|
|
||||||
}
|
|
||||||
|
|
||||||
// Implementation based on approach suggested by Acly
|
|
||||||
// See: https://github.com/ggml-org/llama.cpp/pull/17383#issuecomment-3554227091
|
|
||||||
static ggml_tensor* window_partition(ggml_context* ctx, ggml_tensor* x, int window) {
|
|
||||||
auto [c, w, h, b] = x->ne;
|
|
||||||
// same as
|
|
||||||
// x = ggml_win_part(m, x, window);
|
|
||||||
// x = ggml_reshape_3d(m, x, c, window * window, x->ne[3]);
|
|
||||||
|
|
||||||
int64_t px = (window - w % window) % window;
|
|
||||||
int64_t py = (window - h % window) % window;
|
|
||||||
int64_t npw = (w + px) / window;
|
|
||||||
int64_t nph = (h + py) / window;
|
|
||||||
|
|
||||||
if (px > 0 || py > 0) {
|
|
||||||
x = ggml_pad(ctx, x, 0, int(px), int(py), 0);
|
|
||||||
}
|
|
||||||
x = ggml_reshape_4d(ctx, x, c * window, npw, window, nph * b);
|
|
||||||
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));
|
|
||||||
x = ggml_reshape_4d(ctx, x, c, window, window, npw * nph * b);
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Implementation based on approach suggested by Acly
|
|
||||||
// See: https://github.com/ggml-org/llama.cpp/pull/17383#issuecomment-3554227091
|
|
||||||
static ggml_tensor* window_unpartition(ggml_context* m, ggml_tensor* x, int w, int h, int window) {
|
|
||||||
int64_t c = x->ne[0];
|
|
||||||
// same as
|
|
||||||
// x = ggml_reshape_4d(m, x, c, window, window, x->ne[2]);
|
|
||||||
// x = ggml_win_unpart(m, x, w, h, window);
|
|
||||||
|
|
||||||
int64_t px = (window - w % window) % window;
|
|
||||||
int64_t py = (window - h % window) % window;
|
|
||||||
int64_t npw = (w + px) / window;
|
|
||||||
int64_t nph = (h + py) / window;
|
|
||||||
|
|
||||||
int64_t b = x->ne[3] / (npw * nph);
|
|
||||||
x = ggml_reshape_4d(m, x, c * window, window, npw, nph * b);
|
|
||||||
x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3));
|
|
||||||
x = ggml_reshape_4d(m, x, c, w + px, h + py, b);
|
|
||||||
x = ggml_view_4d(m, x, x->ne[0], w, h, x->ne[3], x->nb[1], x->nb[2], x->nb[3], 0);
|
|
||||||
x = ggml_cont(m, x);
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
|
|
||||||
// build the input after conv2d (inp_raw --> patches)
|
// build the input after conv2d (inp_raw --> patches)
|
||||||
// returns tensor with shape [n_embd, n_patches]
|
// returns tensor with shape [n_embd, n_patches]
|
||||||
|
|
@ -850,219 +753,7 @@ ggml_tensor * clip_graph::build_patch_merge_permute(ggml_tensor * cur, int scale
|
||||||
return cur;
|
return cur;
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor * build_sam(ggml_tensor * inp_raw) {
|
|
||||||
const int n_embd = hparams.sam_n_embd;
|
|
||||||
const int n_layer = hparams.sam_n_layer;
|
|
||||||
const int n_heads = hparams.sam_n_head;
|
|
||||||
const int d_heads = n_embd / n_heads;
|
|
||||||
const int window = hparams.attn_window_size;
|
|
||||||
|
|
||||||
ggml_tensor * inpL;
|
|
||||||
|
|
||||||
inpL = ggml_conv_2d_sk_p0(ctx0, model.patch_embed_proj_w, inp_raw);
|
|
||||||
inpL = ggml_add(ctx0, inpL, ggml_reshape_3d(ctx0, model.patch_embed_proj_b, 1, 1, n_embd));
|
|
||||||
inpL = ggml_cont(ctx0, ggml_permute(ctx0, inpL, 1, 2, 0, 3));
|
|
||||||
|
|
||||||
ggml_tensor * rel_pos_indices_local;
|
|
||||||
ggml_tensor * rel_pos_indices_global;
|
|
||||||
|
|
||||||
rel_pos_indices_local = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, window, window);
|
|
||||||
rel_pos_indices_global = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, inpL->ne[1], inpL->ne[2]);
|
|
||||||
ggml_set_name(rel_pos_indices_local, "rel_pos_indices_local");
|
|
||||||
ggml_set_name(rel_pos_indices_global, "rel_pos_indices_global");
|
|
||||||
ggml_set_input(rel_pos_indices_local);
|
|
||||||
ggml_set_input(rel_pos_indices_global);
|
|
||||||
|
|
||||||
ggml_tensor * cur;
|
|
||||||
const auto tgt_size = inpL->ne[1];
|
|
||||||
const auto str_size = model.pos_embed->ne[1];
|
|
||||||
|
|
||||||
if (str_size != tgt_size) {
|
|
||||||
ggml_tensor * old_pos_embed = nullptr;
|
|
||||||
old_pos_embed = ggml_cont(ctx0, ggml_permute(ctx0, model.pos_embed, 2, 0, 1, 3));
|
|
||||||
ggml_tensor * new_pos_embed = ggml_interpolate(
|
|
||||||
ctx0,
|
|
||||||
old_pos_embed,
|
|
||||||
tgt_size,
|
|
||||||
tgt_size,
|
|
||||||
n_embd,
|
|
||||||
1,
|
|
||||||
ggml_scale_mode::GGML_SCALE_MODE_BICUBIC
|
|
||||||
);
|
|
||||||
new_pos_embed = ggml_cont(ctx0, ggml_permute(ctx0, new_pos_embed, 1, 2, 0, 3));
|
|
||||||
cur = ggml_add(ctx0, inpL, new_pos_embed);
|
|
||||||
} else {
|
|
||||||
cur = ggml_add(ctx0, inpL, model.pos_embed);
|
|
||||||
}
|
|
||||||
|
|
||||||
// loop over layers
|
|
||||||
for (int il = 0; il < n_layer; il++) {
|
|
||||||
auto & layer = model.sam_layers[il];
|
|
||||||
ggml_tensor * shortcut = cur;
|
|
||||||
|
|
||||||
// layernorm1
|
|
||||||
cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, eps, il);
|
|
||||||
|
|
||||||
const int64_t w0 = cur->ne[1];
|
|
||||||
const int64_t h0 = cur->ne[2];
|
|
||||||
|
|
||||||
ggml_tensor * indices;
|
|
||||||
|
|
||||||
if (hparams.is_global_attn(il)) {
|
|
||||||
indices = rel_pos_indices_global;
|
|
||||||
} else {
|
|
||||||
// local attention layer - apply window partition
|
|
||||||
cur = window_partition(ctx0, cur, window);
|
|
||||||
indices = rel_pos_indices_local;
|
|
||||||
}
|
|
||||||
|
|
||||||
const int64_t W = cur->ne[1];
|
|
||||||
const int64_t H = cur->ne[2];
|
|
||||||
// self-attention
|
|
||||||
{
|
|
||||||
const int B = cur->ne[3];
|
|
||||||
|
|
||||||
cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
|
|
||||||
cur = ggml_add(ctx0, cur, layer.qkv_b);
|
|
||||||
cur = ggml_cont(ctx0, cur); // Ensure tensor is contiguous before reshape
|
|
||||||
cur = ggml_reshape_4d(ctx0, cur, n_embd, 3, W*H, B);
|
|
||||||
|
|
||||||
ggml_tensor * Q;
|
|
||||||
ggml_tensor * K;
|
|
||||||
ggml_tensor * V;
|
|
||||||
|
|
||||||
Q = ggml_view_3d (ctx0, cur, n_embd, W*H, B, cur->nb[2], cur->nb[3], 0*cur->nb[1]);
|
|
||||||
Q = ggml_reshape_4d(ctx0, ggml_cont(ctx0, Q), d_heads, n_heads, W*H, B);
|
|
||||||
|
|
||||||
K = ggml_view_3d (ctx0, cur, n_embd, W*H, B, cur->nb[2], cur->nb[3], 1*cur->nb[1]);
|
|
||||||
K = ggml_reshape_4d(ctx0, ggml_cont(ctx0, K), d_heads, n_heads, W*H, B);
|
|
||||||
|
|
||||||
V = ggml_view_3d (ctx0, cur, n_embd, W*H, B, cur->nb[2], cur->nb[3], 2*cur->nb[1]);
|
|
||||||
V = ggml_reshape_4d(ctx0, ggml_cont(ctx0, V), d_heads, n_heads, W*H, B);
|
|
||||||
|
|
||||||
ggml_tensor * mask;
|
|
||||||
ggml_tensor * rw;
|
|
||||||
ggml_tensor * rh;
|
|
||||||
ggml_tensor * qr;
|
|
||||||
|
|
||||||
rw = get_rel_pos(ctx0, layer.rel_pos_w, indices, W, W); // [W, W, C]
|
|
||||||
rh = get_rel_pos(ctx0, layer.rel_pos_h, indices, H, H); // [H, H, C]
|
|
||||||
qr = ggml_permute(ctx0, Q, 0, 2, 1, 3);
|
|
||||||
qr = ggml_reshape_4d(ctx0, ggml_cont(ctx0, qr), d_heads, W, H, B * n_heads);
|
|
||||||
|
|
||||||
|
|
||||||
rw = ggml_mul_mat (ctx0, rw, ggml_cont(ctx0, ggml_permute(ctx0, qr, 0, 2, 1, 3))); // [B*n_heads, W, H, W]
|
|
||||||
rw = ggml_cont (ctx0, ggml_permute(ctx0, rw, 0, 2, 1, 3)); // [B*n_heads, H, W, W]
|
|
||||||
rw = ggml_reshape_4d(ctx0, rw, W, 1, W*H, n_heads*B);
|
|
||||||
rw = ggml_repeat_4d (ctx0, rw, W, H, W*H, n_heads*B);
|
|
||||||
rh = ggml_mul_mat (ctx0, rh, qr); // [B*n_heads, H, W, H]
|
|
||||||
rh = ggml_reshape_4d(ctx0, rh, 1, H, W*H, n_heads*B);
|
|
||||||
mask = ggml_add (ctx0, rw, rh); // [B*n_heads, H*W, H, W]
|
|
||||||
mask = ggml_reshape_4d(ctx0, mask, W*H, W*H, n_heads, B);
|
|
||||||
mask = ggml_cast (ctx0, mask, GGML_TYPE_F16);
|
|
||||||
|
|
||||||
float scale = 1.0f / sqrtf((float)d_heads);
|
|
||||||
|
|
||||||
cur = build_attn(layer.o_w, layer.o_b, Q, K, V, mask, scale,
|
|
||||||
il); // [B, H*W, n_embd]
|
|
||||||
cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur), n_embd, W, H, B);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (hparams.is_global_attn(il) == false) {
|
|
||||||
// local attention layer - reverse window partition
|
|
||||||
cur = window_unpartition(ctx0, cur, w0, h0, window);
|
|
||||||
}
|
|
||||||
|
|
||||||
// re-add the layer input, e.g., residual
|
|
||||||
cur = ggml_add(ctx0, cur, shortcut);
|
|
||||||
|
|
||||||
ggml_tensor * inpFF = cur;
|
|
||||||
|
|
||||||
// layernorm2
|
|
||||||
cur = build_norm(inpFF, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, eps, il);
|
|
||||||
|
|
||||||
// ffn
|
|
||||||
cur = build_ffn(cur, layer.ff_up_w, layer.ff_up_b, nullptr, nullptr, layer.ff_down_w,
|
|
||||||
layer.ff_down_b, hparams.ffn_op, il);
|
|
||||||
|
|
||||||
// residual 2
|
|
||||||
cur = ggml_add(ctx0, cur, inpFF);
|
|
||||||
cb(cur, "sam_layer_out", il);
|
|
||||||
}
|
|
||||||
|
|
||||||
cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3));
|
|
||||||
|
|
||||||
cur = ggml_conv_2d(ctx0, model.neck_0_w, cur, 1, 1, 0, 0, 1, 1);
|
|
||||||
cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3));
|
|
||||||
cur = build_norm(cur, model.neck_1_w, model.neck_1_b, NORM_TYPE_NORMAL, hparams.eps, -1);
|
|
||||||
cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3));
|
|
||||||
|
|
||||||
cur = ggml_conv_2d(ctx0, model.neck_2_w, cur, 1, 1, 1, 1, 1, 1);
|
|
||||||
cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3));
|
|
||||||
cur = build_norm(cur, model.neck_3_w, model.neck_3_b, NORM_TYPE_NORMAL, hparams.eps, -1);
|
|
||||||
cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3));
|
|
||||||
|
|
||||||
cur = ggml_conv_2d(ctx0, model.net_2, cur, 2, 2, 1, 1, 1, 1);
|
|
||||||
cur = ggml_conv_2d(ctx0, model.net_3, cur, 2, 2, 1, 1, 1, 1);
|
|
||||||
cb(cur, "sam_output", -1);
|
|
||||||
|
|
||||||
ggml_build_forward_expand(gf, cur);
|
|
||||||
return cur;
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_tensor * build_dsocr_clip(ggml_tensor * patch_embeds) {
|
|
||||||
ggml_tensor * inp;
|
|
||||||
|
|
||||||
inp = ggml_cpy(ctx0, patch_embeds, ggml_dup_tensor(ctx0, patch_embeds));
|
|
||||||
inp = ggml_reshape_2d(ctx0, inp, inp->ne[0]*inp->ne[1], inp->ne[2]);
|
|
||||||
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3));
|
|
||||||
|
|
||||||
ggml_tensor * new_pos_embd = ggml_cpy(ctx0, model.position_embeddings, ggml_dup_tensor(ctx0, model.position_embeddings));
|
|
||||||
|
|
||||||
int n_pos = new_pos_embd->ne[1]; // +1 for [CLS]
|
|
||||||
const auto tgt_size = static_cast<int>(std::sqrt(inp->ne[1]));
|
|
||||||
const auto src_size = static_cast<int>(std::sqrt(n_pos - 1));
|
|
||||||
|
|
||||||
if (tgt_size != src_size) {
|
|
||||||
ggml_tensor * old_pos_embd;
|
|
||||||
ggml_tensor * cls_tok;
|
|
||||||
|
|
||||||
old_pos_embd = ggml_view_2d(
|
|
||||||
ctx0, new_pos_embd,
|
|
||||||
new_pos_embd->ne[0], src_size * src_size,
|
|
||||||
ggml_row_size(new_pos_embd->type, new_pos_embd->ne[0]), 0
|
|
||||||
);
|
|
||||||
cls_tok = ggml_view_2d(
|
|
||||||
ctx0, new_pos_embd,
|
|
||||||
new_pos_embd->ne[0], 1,
|
|
||||||
ggml_row_size(new_pos_embd->type, new_pos_embd->ne[0]), src_size * src_size
|
|
||||||
);
|
|
||||||
new_pos_embd = ggml_interpolate(ctx0,
|
|
||||||
old_pos_embd,
|
|
||||||
tgt_size,
|
|
||||||
tgt_size,
|
|
||||||
new_pos_embd->ne[0], 1, GGML_SCALE_MODE_BICUBIC
|
|
||||||
);
|
|
||||||
new_pos_embd = ggml_reshape_3d(ctx0, new_pos_embd, n_embd, tgt_size * tgt_size, 1);
|
|
||||||
new_pos_embd = ggml_concat(ctx0, new_pos_embd, cls_tok, 1);
|
|
||||||
n_pos = tgt_size * tgt_size + 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// add CLS token
|
|
||||||
inp = ggml_concat(ctx0, model.class_embedding, inp, 1);
|
|
||||||
|
|
||||||
// for selecting learned pos embd, used by ViT
|
|
||||||
ggml_tensor * positions = ggml_cast(ctx0, ggml_arange(ctx0, 0, n_pos, 1), GGML_TYPE_I32);
|
|
||||||
ggml_tensor * learned_pos_embd = ggml_get_rows(ctx0, new_pos_embd, positions);
|
|
||||||
|
|
||||||
ggml_tensor * cur = build_vit(inp, n_pos, NORM_TYPE_NORMAL, ffn_op_type::FFN_GELU_QUICK,
|
|
||||||
learned_pos_embd, nullptr);
|
|
||||||
|
|
||||||
ggml_build_forward_expand(gf, cur);
|
|
||||||
|
|
||||||
return cur;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
|
static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
|
||||||
GGML_ASSERT(imgs.entries.size() == 1 && "n_batch > 1 is not supported");
|
GGML_ASSERT(imgs.entries.size() == 1 && "n_batch > 1 is not supported");
|
||||||
|
|
@ -1127,9 +818,9 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||||
builder = std::make_unique<clip_graph_llava>(ctx, img);
|
builder = std::make_unique<clip_graph_llava>(ctx, img);
|
||||||
} break;
|
} break;
|
||||||
case PROJECTOR_TYPE_DEEPSEEKOCR:
|
case PROJECTOR_TYPE_DEEPSEEKOCR:
|
||||||
{
|
{
|
||||||
res = graph.build_deepseek_ocr();
|
builder = std::make_unique<clip_graph_deepseekocr>(ctx, img);
|
||||||
} break;
|
} break;
|
||||||
default:
|
default:
|
||||||
GGML_ABORT("missing cgraph builder");
|
GGML_ABORT("missing cgraph builder");
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,364 @@
|
||||||
|
#include "models.h"
|
||||||
|
|
||||||
|
ggml_tensor* clip_graph_deepseekocr::build_sam(ggml_tensor* inp_raw)
|
||||||
|
{
|
||||||
|
const int n_embd = hparams.sam_n_embd;
|
||||||
|
const int n_layer = hparams.sam_n_layer;
|
||||||
|
const int n_heads = hparams.sam_n_head;
|
||||||
|
const int d_heads = n_embd / n_heads;
|
||||||
|
const int window = hparams.attn_window_size;
|
||||||
|
|
||||||
|
ggml_tensor* inpL;
|
||||||
|
|
||||||
|
inpL = ggml_conv_2d_sk_p0(ctx0, model.patch_embed_proj_w, inp_raw);
|
||||||
|
inpL = ggml_add(ctx0, inpL, ggml_reshape_3d(ctx0, model.patch_embed_proj_b, 1, 1, n_embd));
|
||||||
|
inpL = ggml_cont(ctx0, ggml_permute(ctx0, inpL, 1, 2, 0, 3));
|
||||||
|
|
||||||
|
ggml_tensor* rel_pos_indices_local;
|
||||||
|
ggml_tensor* rel_pos_indices_global;
|
||||||
|
|
||||||
|
rel_pos_indices_local = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, window, window);
|
||||||
|
rel_pos_indices_global = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, inpL->ne[1], inpL->ne[2]);
|
||||||
|
ggml_set_name(rel_pos_indices_local, "rel_pos_indices_local");
|
||||||
|
ggml_set_name(rel_pos_indices_global, "rel_pos_indices_global");
|
||||||
|
ggml_set_input(rel_pos_indices_local);
|
||||||
|
ggml_set_input(rel_pos_indices_global);
|
||||||
|
|
||||||
|
ggml_tensor* cur;
|
||||||
|
const auto tgt_size = inpL->ne[1];
|
||||||
|
const auto str_size = model.pos_embed->ne[1];
|
||||||
|
|
||||||
|
if (str_size != tgt_size)
|
||||||
|
{
|
||||||
|
ggml_tensor* old_pos_embed = nullptr;
|
||||||
|
old_pos_embed = ggml_cont(ctx0, ggml_permute(ctx0, model.pos_embed, 2, 0, 1, 3));
|
||||||
|
ggml_tensor* new_pos_embed = ggml_interpolate(
|
||||||
|
ctx0,
|
||||||
|
old_pos_embed,
|
||||||
|
tgt_size,
|
||||||
|
tgt_size,
|
||||||
|
n_embd,
|
||||||
|
1,
|
||||||
|
ggml_scale_mode::GGML_SCALE_MODE_BICUBIC
|
||||||
|
);
|
||||||
|
new_pos_embed = ggml_cont(ctx0, ggml_permute(ctx0, new_pos_embed, 1, 2, 0, 3));
|
||||||
|
cur = ggml_add(ctx0, inpL, new_pos_embed);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
cur = ggml_add(ctx0, inpL, model.pos_embed);
|
||||||
|
}
|
||||||
|
|
||||||
|
// loop over layers
|
||||||
|
for (int il = 0; il < n_layer; il++)
|
||||||
|
{
|
||||||
|
auto& layer = model.sam_layers[il];
|
||||||
|
ggml_tensor* shortcut = cur;
|
||||||
|
|
||||||
|
// layernorm1
|
||||||
|
cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, eps, il);
|
||||||
|
|
||||||
|
const int64_t w0 = cur->ne[1];
|
||||||
|
const int64_t h0 = cur->ne[2];
|
||||||
|
|
||||||
|
ggml_tensor* indices;
|
||||||
|
|
||||||
|
if (hparams.is_global_attn(il))
|
||||||
|
{
|
||||||
|
indices = rel_pos_indices_global;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
// local attention layer - apply window partition
|
||||||
|
cur = window_partition(cur, window);
|
||||||
|
indices = rel_pos_indices_local;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t W = cur->ne[1];
|
||||||
|
const int64_t H = cur->ne[2];
|
||||||
|
// self-attention
|
||||||
|
{
|
||||||
|
const int B = cur->ne[3];
|
||||||
|
|
||||||
|
cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
|
||||||
|
cur = ggml_add(ctx0, cur, layer.qkv_b);
|
||||||
|
cur = ggml_cont(ctx0, cur); // Ensure tensor is contiguous before reshape
|
||||||
|
cur = ggml_reshape_4d(ctx0, cur, n_embd, 3, W * H, B);
|
||||||
|
|
||||||
|
ggml_tensor* Q;
|
||||||
|
ggml_tensor* K;
|
||||||
|
ggml_tensor* V;
|
||||||
|
|
||||||
|
Q = ggml_view_3d(ctx0, cur, n_embd, W * H, B, cur->nb[2], cur->nb[3], 0 * cur->nb[1]);
|
||||||
|
Q = ggml_reshape_4d(ctx0, ggml_cont(ctx0, Q), d_heads, n_heads, W * H, B);
|
||||||
|
|
||||||
|
K = ggml_view_3d(ctx0, cur, n_embd, W * H, B, cur->nb[2], cur->nb[3], 1 * cur->nb[1]);
|
||||||
|
K = ggml_reshape_4d(ctx0, ggml_cont(ctx0, K), d_heads, n_heads, W * H, B);
|
||||||
|
|
||||||
|
V = ggml_view_3d(ctx0, cur, n_embd, W * H, B, cur->nb[2], cur->nb[3], 2 * cur->nb[1]);
|
||||||
|
V = ggml_reshape_4d(ctx0, ggml_cont(ctx0, V), d_heads, n_heads, W * H, B);
|
||||||
|
|
||||||
|
ggml_tensor* mask;
|
||||||
|
ggml_tensor* rw;
|
||||||
|
ggml_tensor* rh;
|
||||||
|
ggml_tensor* qr;
|
||||||
|
|
||||||
|
rw = get_rel_pos(layer.rel_pos_w, indices, W, W); // [W, W, C]
|
||||||
|
rh = get_rel_pos(layer.rel_pos_h, indices, H, H); // [H, H, C]
|
||||||
|
qr = ggml_permute(ctx0, Q, 0, 2, 1, 3);
|
||||||
|
qr = ggml_reshape_4d(ctx0, ggml_cont(ctx0, qr), d_heads, W, H, B * n_heads);
|
||||||
|
|
||||||
|
|
||||||
|
rw = ggml_mul_mat(ctx0, rw, ggml_cont(ctx0, ggml_permute(ctx0, qr, 0, 2, 1, 3))); // [B*n_heads, W, H, W]
|
||||||
|
rw = ggml_cont(ctx0, ggml_permute(ctx0, rw, 0, 2, 1, 3)); // [B*n_heads, H, W, W]
|
||||||
|
rw = ggml_reshape_4d(ctx0, rw, W, 1, W * H, n_heads * B);
|
||||||
|
rw = ggml_repeat_4d(ctx0, rw, W, H, W * H, n_heads * B);
|
||||||
|
rh = ggml_mul_mat(ctx0, rh, qr); // [B*n_heads, H, W, H]
|
||||||
|
rh = ggml_reshape_4d(ctx0, rh, 1, H, W * H, n_heads * B);
|
||||||
|
mask = ggml_add(ctx0, rw, rh); // [B*n_heads, H*W, H, W]
|
||||||
|
mask = ggml_reshape_4d(ctx0, mask, W * H, W * H, n_heads, B);
|
||||||
|
mask = ggml_cast(ctx0, mask, GGML_TYPE_F16);
|
||||||
|
|
||||||
|
float scale = 1.0f / sqrtf((float)d_heads);
|
||||||
|
|
||||||
|
cur = build_attn(layer.o_w, layer.o_b, Q, K, V, mask, scale,
|
||||||
|
il); // [B, H*W, n_embd]
|
||||||
|
cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur), n_embd, W, H, B);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (hparams.is_global_attn(il) == false)
|
||||||
|
{
|
||||||
|
// local attention layer - reverse window partition
|
||||||
|
cur = window_unpartition(cur, w0, h0, window);
|
||||||
|
}
|
||||||
|
|
||||||
|
// re-add the layer input, e.g., residual
|
||||||
|
cur = ggml_add(ctx0, cur, shortcut);
|
||||||
|
|
||||||
|
ggml_tensor* inpFF = cur;
|
||||||
|
|
||||||
|
// layernorm2
|
||||||
|
cur = build_norm(inpFF, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, eps, il);
|
||||||
|
|
||||||
|
// ffn
|
||||||
|
cur = build_ffn(cur, layer.ff_up_w, layer.ff_up_b, nullptr, nullptr, layer.ff_down_w,
|
||||||
|
layer.ff_down_b, hparams.ffn_op, il);
|
||||||
|
|
||||||
|
// residual 2
|
||||||
|
cur = ggml_add(ctx0, cur, inpFF);
|
||||||
|
cb(cur, "sam_layer_out", il);
|
||||||
|
}
|
||||||
|
|
||||||
|
cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3));
|
||||||
|
|
||||||
|
cur = ggml_conv_2d(ctx0, model.neck_0_w, cur, 1, 1, 0, 0, 1, 1);
|
||||||
|
cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3));
|
||||||
|
cur = build_norm(cur, model.neck_1_w, model.neck_1_b, NORM_TYPE_NORMAL, hparams.eps, -1);
|
||||||
|
cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3));
|
||||||
|
|
||||||
|
cur = ggml_conv_2d(ctx0, model.neck_2_w, cur, 1, 1, 1, 1, 1, 1);
|
||||||
|
cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3));
|
||||||
|
cur = build_norm(cur, model.neck_3_w, model.neck_3_b, NORM_TYPE_NORMAL, hparams.eps, -1);
|
||||||
|
cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3));
|
||||||
|
|
||||||
|
cur = ggml_conv_2d(ctx0, model.net_2, cur, 2, 2, 1, 1, 1, 1);
|
||||||
|
cur = ggml_conv_2d(ctx0, model.net_3, cur, 2, 2, 1, 1, 1, 1);
|
||||||
|
cb(cur, "sam_output", -1);
|
||||||
|
|
||||||
|
ggml_build_forward_expand(gf, cur);
|
||||||
|
return cur;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
ggml_cgraph* clip_graph_deepseekocr::build()
|
||||||
|
{
|
||||||
|
//patch embedding
|
||||||
|
ggml_tensor* inp_raw = build_inp_raw();
|
||||||
|
ggml_tensor* sam_out = build_sam(inp_raw);
|
||||||
|
ggml_tensor* clip_out = build_dsocr_clip(sam_out);
|
||||||
|
|
||||||
|
int clip_n_patches = sam_out->ne[0] * sam_out->ne[1];
|
||||||
|
|
||||||
|
sam_out = ggml_cont(ctx0, ggml_permute(ctx0, sam_out, 1, 2, 0, 3));
|
||||||
|
sam_out = ggml_reshape_2d(ctx0, sam_out, sam_out->ne[0], clip_n_patches);
|
||||||
|
clip_out = ggml_view_2d(ctx0, clip_out, n_embd, clip_n_patches, clip_out->nb[1], clip_out->nb[1]);
|
||||||
|
|
||||||
|
ggml_tensor* cur;
|
||||||
|
cur = ggml_concat(ctx0, clip_out, sam_out, 0);
|
||||||
|
cur = ggml_reshape_2d(ctx0, cur, 2 * n_embd, clip_n_patches);
|
||||||
|
cur = ggml_cont(ctx0, cur);
|
||||||
|
cur = ggml_mul_mat(ctx0, model.fc_w, cur);
|
||||||
|
cur = ggml_add(ctx0, cur, model.fc_b);
|
||||||
|
|
||||||
|
const auto h = static_cast<int>(std::sqrt(static_cast<float>(cur->ne[1])));
|
||||||
|
const auto w = h;
|
||||||
|
const auto n_dim = cur->ne[0];
|
||||||
|
|
||||||
|
ggml_tensor* imgnl;
|
||||||
|
ggml_tensor* vs;
|
||||||
|
|
||||||
|
imgnl = ggml_repeat_4d(ctx0, model.image_newline, n_dim, 1, h, 1);
|
||||||
|
vs = ggml_reshape_2d(ctx0, model.view_seperator, n_dim, 1); // (n_dim, 1)
|
||||||
|
cur = ggml_reshape_3d(ctx0, cur, n_dim, w, h);
|
||||||
|
cur = ggml_reshape_2d(ctx0, ggml_concat(ctx0, cur, imgnl, 1), n_dim, (w + 1) * h);
|
||||||
|
cur = ggml_concat(ctx0, cur, vs, 1); // (n_dim, h*(w+1) + 1)
|
||||||
|
|
||||||
|
cb(cur, "dsocr_output", -1);
|
||||||
|
|
||||||
|
ggml_build_forward_expand(gf, cur);
|
||||||
|
return gf;
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor* clip_graph_deepseekocr::build_dsocr_clip(ggml_tensor* patch_embeds)
|
||||||
|
{
|
||||||
|
ggml_tensor* inp;
|
||||||
|
|
||||||
|
inp = ggml_cpy(ctx0, patch_embeds, ggml_dup_tensor(ctx0, patch_embeds));
|
||||||
|
inp = ggml_reshape_2d(ctx0, inp, inp->ne[0] * inp->ne[1], inp->ne[2]);
|
||||||
|
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3));
|
||||||
|
|
||||||
|
ggml_tensor* new_pos_embd = ggml_cpy(ctx0, model.position_embeddings,
|
||||||
|
ggml_dup_tensor(ctx0, model.position_embeddings));
|
||||||
|
|
||||||
|
int n_pos = new_pos_embd->ne[1]; // +1 for [CLS]
|
||||||
|
const auto tgt_size = static_cast<int>(std::sqrt(inp->ne[1]));
|
||||||
|
const auto src_size = static_cast<int>(std::sqrt(n_pos - 1));
|
||||||
|
|
||||||
|
if (tgt_size != src_size)
|
||||||
|
{
|
||||||
|
ggml_tensor* old_pos_embd;
|
||||||
|
ggml_tensor* cls_tok;
|
||||||
|
|
||||||
|
old_pos_embd = ggml_view_2d(
|
||||||
|
ctx0, new_pos_embd,
|
||||||
|
new_pos_embd->ne[0], src_size * src_size,
|
||||||
|
ggml_row_size(new_pos_embd->type, new_pos_embd->ne[0]), 0
|
||||||
|
);
|
||||||
|
cls_tok = ggml_view_2d(
|
||||||
|
ctx0, new_pos_embd,
|
||||||
|
new_pos_embd->ne[0], 1,
|
||||||
|
ggml_row_size(new_pos_embd->type, new_pos_embd->ne[0]), src_size * src_size
|
||||||
|
);
|
||||||
|
new_pos_embd = ggml_interpolate(ctx0,
|
||||||
|
old_pos_embd,
|
||||||
|
tgt_size,
|
||||||
|
tgt_size,
|
||||||
|
new_pos_embd->ne[0], 1, GGML_SCALE_MODE_BICUBIC
|
||||||
|
);
|
||||||
|
new_pos_embd = ggml_reshape_3d(ctx0, new_pos_embd, n_embd, tgt_size * tgt_size, 1);
|
||||||
|
new_pos_embd = ggml_concat(ctx0, new_pos_embd, cls_tok, 1);
|
||||||
|
n_pos = tgt_size * tgt_size + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// add CLS token
|
||||||
|
inp = ggml_concat(ctx0, model.class_embedding, inp, 1);
|
||||||
|
|
||||||
|
// for selecting learned pos embd, used by ViT
|
||||||
|
ggml_tensor* positions = ggml_cast(ctx0, ggml_arange(ctx0, 0, n_pos, 1), GGML_TYPE_I32);
|
||||||
|
ggml_tensor* learned_pos_embd = ggml_get_rows(ctx0, new_pos_embd, positions);
|
||||||
|
|
||||||
|
ggml_tensor* cur = build_vit(inp, n_pos, NORM_TYPE_NORMAL, ffn_op_type::FFN_GELU_QUICK,
|
||||||
|
learned_pos_embd, nullptr);
|
||||||
|
|
||||||
|
ggml_build_forward_expand(gf, cur);
|
||||||
|
|
||||||
|
return cur;
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor * clip_graph_deepseekocr::get_rel_pos(
|
||||||
|
ggml_tensor * rel_pos, // [L, C]
|
||||||
|
ggml_tensor * indices, // [q_size, k_size]
|
||||||
|
int q_size,
|
||||||
|
int k_size
|
||||||
|
) {
|
||||||
|
const int64_t C = rel_pos->ne[0]; // channels
|
||||||
|
const int64_t L = rel_pos->ne[1]; // length
|
||||||
|
|
||||||
|
GGML_ASSERT(indices != nullptr);
|
||||||
|
GGML_ASSERT(indices->type == GGML_TYPE_I32);
|
||||||
|
GGML_ASSERT(indices->ne[0] == k_size);
|
||||||
|
GGML_ASSERT(indices->ne[1] == q_size);
|
||||||
|
|
||||||
|
const auto max_rel_dist = 2*std::max(q_size, k_size) - 1;
|
||||||
|
ggml_tensor * cur = rel_pos;
|
||||||
|
|
||||||
|
if (max_rel_dist != L) {
|
||||||
|
// Linear interpolation
|
||||||
|
int64_t ne0 = cur->ne[0];
|
||||||
|
int64_t ne1 = cur->ne[1];
|
||||||
|
int64_t ne2 = cur->ne[2];
|
||||||
|
int64_t ne3 = cur->ne[3];
|
||||||
|
|
||||||
|
cur = ggml_reshape_3d(
|
||||||
|
ctx0,
|
||||||
|
ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 0, 2, 3)),
|
||||||
|
ne1, 1, ne0*ne2*ne3
|
||||||
|
);
|
||||||
|
cur = ggml_reshape_4d(
|
||||||
|
ctx0,
|
||||||
|
ggml_interpolate(
|
||||||
|
ctx0,
|
||||||
|
cur,
|
||||||
|
max_rel_dist, 1, ne0*ne2*ne3, 1,
|
||||||
|
ggml_scale_mode::GGML_SCALE_MODE_BILINEAR
|
||||||
|
),
|
||||||
|
max_rel_dist, ne0, ne2, ne3
|
||||||
|
);
|
||||||
|
cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 0, 2, 3));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flatten indices to 1D for ggml_get_rows
|
||||||
|
int qk = q_size * k_size;
|
||||||
|
|
||||||
|
cur = ggml_reshape_3d(
|
||||||
|
ctx0,
|
||||||
|
ggml_get_rows(ctx0, cur, ggml_reshape_1d(ctx0, indices, qk)),
|
||||||
|
C, k_size, q_size
|
||||||
|
);
|
||||||
|
|
||||||
|
return cur; // [C, k_size, q_size]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implementation based on approach suggested by Acly
|
||||||
|
// See: https://github.com/ggml-org/llama.cpp/pull/17383#issuecomment-3554227091
|
||||||
|
ggml_tensor* clip_graph_deepseekocr::window_partition(ggml_tensor* x, int window) {
|
||||||
|
auto [c, w, h, b] = x->ne;
|
||||||
|
// same as
|
||||||
|
// x = ggml_win_part(m, x, window);
|
||||||
|
// x = ggml_reshape_3d(m, x, c, window * window, x->ne[3]);
|
||||||
|
|
||||||
|
int64_t px = (window - w % window) % window;
|
||||||
|
int64_t py = (window - h % window) % window;
|
||||||
|
int64_t npw = (w + px) / window;
|
||||||
|
int64_t nph = (h + py) / window;
|
||||||
|
|
||||||
|
if (px > 0 || py > 0) {
|
||||||
|
x = ggml_pad(ctx0, x, 0, int(px), int(py), 0);
|
||||||
|
}
|
||||||
|
x = ggml_reshape_4d(ctx0, x, c * window, npw, window, nph * b);
|
||||||
|
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 0, 2, 1, 3));
|
||||||
|
x = ggml_reshape_4d(ctx0, x, c, window, window, npw * nph * b);
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implementation based on approach suggested by Acly
|
||||||
|
// See: https://github.com/ggml-org/llama.cpp/pull/17383#issuecomment-3554227091
|
||||||
|
ggml_tensor* clip_graph_deepseekocr::window_unpartition(ggml_tensor* x, int w, int h, int window) {
|
||||||
|
int64_t c = x->ne[0];
|
||||||
|
// same as
|
||||||
|
// x = ggml_reshape_4d(m, x, c, window, window, x->ne[2]);
|
||||||
|
// x = ggml_win_unpart(m, x, w, h, window);
|
||||||
|
|
||||||
|
int64_t px = (window - w % window) % window;
|
||||||
|
int64_t py = (window - h % window) % window;
|
||||||
|
int64_t npw = (w + px) / window;
|
||||||
|
int64_t nph = (h + py) / window;
|
||||||
|
|
||||||
|
int64_t b = x->ne[3] / (npw * nph);
|
||||||
|
x = ggml_reshape_4d(ctx0, x, c * window, window, npw, nph * b);
|
||||||
|
x = ggml_cont(ctx0, ggml_permute(ctx0, x, 0, 2, 1, 3));
|
||||||
|
x = ggml_reshape_4d(ctx0, x, c, w + px, h + py, b);
|
||||||
|
x = ggml_view_4d(ctx0, x, x->ne[0], w, h, x->ne[3], x->nb[1], x->nb[2], x->nb[3], 0);
|
||||||
|
x = ggml_cont(ctx0, x);
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
@ -56,3 +56,14 @@ struct clip_graph_whisper_enc : clip_graph {
|
||||||
clip_graph_whisper_enc(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
|
clip_graph_whisper_enc(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
|
||||||
ggml_cgraph * build() override;
|
ggml_cgraph * build() override;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct clip_graph_deepseekocr : clip_graph {
|
||||||
|
clip_graph_deepseekocr(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
|
||||||
|
ggml_cgraph * build() override;
|
||||||
|
|
||||||
|
ggml_tensor * build_sam(ggml_tensor * inp_raw);
|
||||||
|
ggml_tensor * build_dsocr_clip(ggml_tensor * patch_embeds);
|
||||||
|
ggml_tensor * get_rel_pos(ggml_tensor * rel_pos, ggml_tensor * indices, int q_size, int k_size);
|
||||||
|
ggml_tensor * window_partition(ggml_tensor * x, int window);
|
||||||
|
ggml_tensor * window_unpartition(ggml_tensor * x, int w, int h, int window);
|
||||||
|
};
|
||||||
|
|
|
||||||
|
|
@ -43,7 +43,7 @@ ggml_cgraph * clip_graph_siglip::build() {
|
||||||
// https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578
|
// https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578
|
||||||
const int scale_factor = model.hparams.n_merge;
|
const int scale_factor = model.hparams.n_merge;
|
||||||
cur = build_patch_merge_permute(cur, scale_factor);
|
cur = build_patch_merge_permute(cur, scale_factor);
|
||||||
cur = ggml_mul_mat(ctx0, model.projection, cur);
|
cur = ggml_mul_mat(ctx0, model.fc_w, cur);
|
||||||
|
|
||||||
} else if (proj_type == PROJECTOR_TYPE_LFM2) {
|
} else if (proj_type == PROJECTOR_TYPE_LFM2) {
|
||||||
// pixel unshuffle block
|
// pixel unshuffle block
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue