mtmd: simplify SAM patch embedding

This commit is contained in:
bluebread 2025-12-01 07:31:24 +00:00
parent c5f4c64fe4
commit 95239f92b9
1 changed files with 10 additions and 13 deletions

View File

@ -663,28 +663,24 @@ struct clip_graph {
return gf;
}
ggml_tensor * build_sam_enc(ggml_tensor * inp_raw,
const int enc_image_size = 1024
) {
ggml_tensor * build_sam_enc(ggml_tensor * inp_raw) {
constexpr int enc_n_embd = 768;
constexpr int _depth = 12;
constexpr int enc_n_heads = 12;
constexpr int enc_d_heads = enc_n_embd / enc_n_heads;
// constexpr int _prompt_n_embd = 256;
constexpr int enc_patch_size = 16;
// constexpr int _window_size = 14;
const int enc_n_patches = enc_image_size / enc_patch_size; // 64
ggml_tensor * inpL = build_enc_inp(inp_raw, enc_patch_size, enc_n_patches, enc_n_embd);
ggml_tensor * cur = nullptr;
ggml_tensor * inpL;
inpL = ggml_conv_2d_sk_p0(ctx0, model.patch_embed_proj_w, inp_raw);
inpL = ggml_add(ctx0, inpL, ggml_reshape_3d(ctx0, model.patch_embed_proj_b, 1, 1, enc_n_embd));
inpL = ggml_cont(ctx0, ggml_permute(ctx0, inpL, 1, 2, 0, 3));
ggml_tensor * cur;
const auto tgt_size = inpL->ne[1];
const auto str_size = model.pos_embed->ne[1];
if (str_size != tgt_size) {
ggml_tensor * old_pos_embed = nullptr;
old_pos_embed = ggml_cont(ctx0, ggml_permute(ctx0, model.pos_embed, 2, 0, 1, 3));
// TODO: ggml_interpolate doesn't support bicubic model for CUDA backend
ggml_tensor * new_pos_embed = ggml_interpolate(
ctx0,
old_pos_embed,
@ -838,7 +834,7 @@ struct clip_graph {
ggml_cgraph * build_deepseek_ocr() {
//patch embedding
ggml_tensor * inp_raw = build_inp_raw();
ggml_tensor * global_features_1 = build_sam_enc(inp_raw, std::max(img.nx, img.ny));
ggml_tensor * global_features_1 = build_sam_enc(inp_raw);
ggml_tensor * global_features_2 = build_dp_ocr_clip(global_features_1);
// FIXME remove n_patches is hardcoded
@ -5819,6 +5815,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
bool is_stored = false;
std::vector<std::string> patterns = {
/* Add tensor names here to dump (e.g. "sam_output") */
"inpL", "inp_raw_cpy"
};
for (auto & p : patterns) {