mtmd: simplify SAM patch embedding
This commit is contained in:
parent
c5f4c64fe4
commit
95239f92b9
|
|
@ -663,28 +663,24 @@ struct clip_graph {
|
|||
return gf;
|
||||
}
|
||||
|
||||
ggml_tensor * build_sam_enc(ggml_tensor * inp_raw,
|
||||
const int enc_image_size = 1024
|
||||
) {
|
||||
ggml_tensor * build_sam_enc(ggml_tensor * inp_raw) {
|
||||
constexpr int enc_n_embd = 768;
|
||||
constexpr int _depth = 12;
|
||||
constexpr int enc_n_heads = 12;
|
||||
constexpr int enc_d_heads = enc_n_embd / enc_n_heads;
|
||||
// constexpr int _prompt_n_embd = 256;
|
||||
constexpr int enc_patch_size = 16;
|
||||
// constexpr int _window_size = 14;
|
||||
|
||||
const int enc_n_patches = enc_image_size / enc_patch_size; // 64
|
||||
|
||||
ggml_tensor * inpL = build_enc_inp(inp_raw, enc_patch_size, enc_n_patches, enc_n_embd);
|
||||
ggml_tensor * cur = nullptr;
|
||||
|
||||
ggml_tensor * inpL;
|
||||
|
||||
inpL = ggml_conv_2d_sk_p0(ctx0, model.patch_embed_proj_w, inp_raw);
|
||||
inpL = ggml_add(ctx0, inpL, ggml_reshape_3d(ctx0, model.patch_embed_proj_b, 1, 1, enc_n_embd));
|
||||
inpL = ggml_cont(ctx0, ggml_permute(ctx0, inpL, 1, 2, 0, 3));
|
||||
|
||||
ggml_tensor * cur;
|
||||
const auto tgt_size = inpL->ne[1];
|
||||
const auto str_size = model.pos_embed->ne[1];
|
||||
if (str_size != tgt_size) {
|
||||
ggml_tensor * old_pos_embed = nullptr;
|
||||
old_pos_embed = ggml_cont(ctx0, ggml_permute(ctx0, model.pos_embed, 2, 0, 1, 3));
|
||||
// TODO: ggml_interpolate doesn't support bicubic model for CUDA backend
|
||||
ggml_tensor * new_pos_embed = ggml_interpolate(
|
||||
ctx0,
|
||||
old_pos_embed,
|
||||
|
|
@ -838,7 +834,7 @@ struct clip_graph {
|
|||
ggml_cgraph * build_deepseek_ocr() {
|
||||
//patch embedding
|
||||
ggml_tensor * inp_raw = build_inp_raw();
|
||||
ggml_tensor * global_features_1 = build_sam_enc(inp_raw, std::max(img.nx, img.ny));
|
||||
ggml_tensor * global_features_1 = build_sam_enc(inp_raw);
|
||||
ggml_tensor * global_features_2 = build_dp_ocr_clip(global_features_1);
|
||||
|
||||
// FIXME remove n_patches is hardcoded
|
||||
|
|
@ -5819,6 +5815,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
|||
bool is_stored = false;
|
||||
std::vector<std::string> patterns = {
|
||||
/* Add tensor names here to dump (e.g. "sam_output") */
|
||||
"inpL", "inp_raw_cpy"
|
||||
};
|
||||
|
||||
for (auto & p : patterns) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue