window partitioning using standard ggml ops

This commit is contained in:
Saba Fallah 2025-11-20 10:07:54 +01:00
parent 89afda8da9
commit 88032f46b1
1 changed files with 46 additions and 4 deletions

View File

@ -690,7 +690,8 @@ struct clip_graph {
if (hparams.is_global_attn(il) == false) { if (hparams.is_global_attn(il) == false) {
// local attention layer - apply window partition // local attention layer - apply window partition
// ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L169-L172 // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L169-L172
cur = ggml_win_part(ctx0, cur, 14); //cur = ggml_win_part(ctx0, cur, 14);
cur = window_partition(ctx0, cur, 14);
} }
const int64_t W = cur->ne[1]; const int64_t W = cur->ne[1];
@ -762,7 +763,7 @@ struct clip_graph {
if (hparams.is_global_attn(il) == false) { if (hparams.is_global_attn(il) == false) {
// local attention layer - reverse window partition // local attention layer - reverse window partition
cur = ggml_win_unpart(ctx0, cur, w0, h0, 14); cur = window_unpartition(ctx0, cur, w0, h0, 14);
} }
// re-add the layer input, e.g., residual // re-add the layer input, e.g., residual
@ -865,9 +866,10 @@ struct clip_graph {
// 1) global_features: [n_dim, h*w] -> [n_dim, w, h] -> [h, w, n_dim] // 1) global_features: [n_dim, h*w] -> [n_dim, w, h] -> [h, w, n_dim]
ggml_tensor * t = ggml_reshape_4d(ctx0, global_features, 1280, 64, 64, 1); // (n_dim, w, h) ggml_tensor * t = ggml_reshape_4d(ctx0, global_features, 1280, 64, 64, 1); // (n_dim, w, h)
t = ggml_permute(ctx0, t, 2, 1, 0, 3); // (h, w, n_dim) t = ggml_cont(ctx0, ggml_permute(ctx0, t, 2, 1, 0, 3)); // (h, w, n_dim)
ggml_tensor * nl = ggml_permute(ctx0, model.image_newline, 2, 1, 0, 3); ggml_tensor * nl = ggml_cont(ctx0,ggml_permute(ctx0, model.image_newline, 2, 1, 0, 3));
nl = ggml_repeat_4d(ctx0, nl, 64, 1, 1280, 1); // n_pos rows nl = ggml_repeat_4d(ctx0, nl, 64, 1, 1280, 1); // n_pos rows
nl = ggml_cont(ctx0, nl);
// 2) image_newline: [n_dim] -> [1, 1, n_dim] -> repeat to [h, 1, n_dim] // 2) image_newline: [n_dim] -> [1, 1, n_dim] -> repeat to [h, 1, n_dim]
@ -2464,6 +2466,46 @@ private:
return inpL; return inpL;
} }
static ggml_tensor* window_partition(ggml_context* ctx, ggml_tensor* x, int window) {
auto [c, w, h, b] = x->ne;
// same as
// x = ggml_win_part(m, x, window);
// x = ggml_reshape_3d(m, x, c, window * window, x->ne[3]);
int64_t px = (window - w % window) % window;
int64_t py = (window - h % window) % window;
int64_t npw = (w + px) / window;
int64_t nph = (h + py) / window;
if (px > 0 || py > 0) {
x = ggml_pad(ctx, x, 0, int(px), int(py), 0);
}
x = ggml_reshape_4d(ctx, x, c * window, npw, window, nph * b);
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));
x = ggml_reshape_4d(ctx, x, c, window ,window, npw * nph * b);
return x;
}
static ggml_tensor* window_unpartition(ggml_context* m, ggml_tensor* x, int w, int h, int window) {
int64_t c = x->ne[0];
// same as
// x = ggml_reshape_4d(m, x, c, window, window, x->ne[2]);
// x = ggml_win_unpart(m, x, w, h, window);
int64_t px = (window - w % window) % window;
int64_t py = (window - h % window) % window;
int64_t npw = (w + px) / window;
int64_t nph = (h + py) / window;
int64_t b = x->ne[3] / (npw * nph);
x = ggml_reshape_4d(m, x, c * window, window, npw, nph * b);
x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3));
x = ggml_reshape_4d(m, x, c, w + px, h + py, b);
x = ggml_view_4d(m, x, x->ne[0], w, h, x->ne[3], x->nb[1], x->nb[2], x->nb[3], 0);
x = ggml_cont(m, x);
return x;
}
// build the input after conv2d (inp_raw --> patches) // build the input after conv2d (inp_raw --> patches)
// returns tensor with shape [n_embd, n_patches] // returns tensor with shape [n_embd, n_patches]
ggml_tensor * build_enc_inp(ggml_tensor * inp_raw, ggml_tensor * build_enc_inp(ggml_tensor * inp_raw,