mtmd: minor changes
commit effe66958e
parent 7b8d735c90
@@ -739,13 +739,14 @@ struct clip_graph {
         struct ggml_tensor * q_r = ggml_reshape_4d(ctx0, Qcur, enc_d_heads, W, H, B * enc_n_heads);
 
-        struct ggml_tensor * rel_w = ggml_cont(
-            ctx0,
-            ggml_permute(ctx0, ggml_mul_mat(ctx0, rw, ggml_cont(ctx0, ggml_permute(ctx0, q_r, 0, 2, 1, 3))), 0,
-            2, 1, 3));
+        struct ggml_tensor * rel_w = ggml_cont(ctx0, ggml_permute(ctx0,
+            ggml_mul_mat(ctx0,
+                rw,
+                ggml_cont(ctx0, ggml_permute(ctx0, q_r, 0, 2, 1, 3))),
+            0, 2, 1, 3));
         struct ggml_tensor * rel_h = ggml_mul_mat(ctx0, rh, q_r);
 
-        struct ggml_tensor * attn = add_rel_pos_inplace(ctx0, KQ_scaled, rel_w, rel_h, W);
+        struct ggml_tensor * attn = add_rel_pos_inplace(ctx0, KQ_scaled, rel_w, rel_h);
 
         struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, attn);
 
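What this hunk feeds into the soft-max is scaled dot-product attention with decomposed relative position terms. Assuming rh and rw hold the per-axis relative positional embeddings already gathered for every (query row, key row) and (query column, key column) offset, as in SAM-style attention (an assumption about surrounding code that is not shown in this diff), the score for query position (qh, qw) and key position (kh, kw) within one head is roughly:

    attn[(qh,qw),(kh,kw)] = (q[(qh,qw)] . k[(kh,kw)]) / sqrt(d)
                          + q[(qh,qw)] . rh[qh, kh]
                          + q[(qh,qw)] . rw[qw, kw]

rel_h and rel_w above are the last two dot-product terms computed for all positions at once, and add_rel_pos_inplace broadcasts them onto KQ_scaled before ggml_soft_max_inplace is applied.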
@@ -2466,7 +2467,7 @@ private:
         return inpL;
     }
 
-    // attn: [k_h*k_w, q_h*q_w]
+    // attn: [q_h*q_w, k_h*k_w]
    // rel_h: [q_h, q_w, k_h]
    // rel_w: [q_h, q_w, k_w]
 
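The shape comments list dimensions from outermost to innermost, while ggml stores extents in ne[] with ne[0] as the innermost (fastest-varying) dimension; that is why the function below reads k_h from rel_h->ne[0] and q_h from rel_h->ne[2]. A minimal standalone sketch of that mapping, with made-up grid sizes (names and values here are illustrative only, not part of the patch):

    #include <cassert>

    int main() {
        // hypothetical query/key grid extents, for illustration only
        const int q_h = 4, q_w = 6, k_h = 4, k_w = 6;

        // rel_h: [q_h, q_w, k_h]   ->  rel_h->ne = { k_h, q_w, q_h, ... }
        const int rel_h_ne[3] = { k_h, q_w, q_h };
        // rel_w: [q_h, q_w, k_w]   ->  rel_w->ne = { k_w, q_w, q_h, ... }
        const int rel_w_ne[3] = { k_w, q_w, q_h };
        // attn:  [q_h*q_w, k_h*k_w] ->  attn->ne  = { k_h*k_w, q_h*q_w, ... }
        const int attn_ne[2] = { k_h * k_w, q_h * q_w };

        // the same relations the GGML_ASSERTs in the next hunk check
        assert(rel_h_ne[0] == k_h && rel_w_ne[0] == k_w);
        assert(rel_w_ne[1] == q_w && rel_w_ne[2] == q_h);
        assert(attn_ne[0] == k_h * k_w && attn_ne[1] == q_h * q_w);
        return 0;
    }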
@@ -2474,24 +2475,29 @@ private:
             ggml_context * ctx,
             ggml_tensor * attn,
             ggml_tensor * rel_w,
-            ggml_tensor * rel_h,
-            int q_size
+            ggml_tensor * rel_h
             ) {
-
-        ggml_tensor *attn_4d =
-            ggml_reshape_4d(ctx, attn, q_size,q_size, attn->ne[1], attn->ne[2]);
+        const int k_w = rel_w->ne[0];
+        const int k_h = rel_h->ne[0];
+        const int q_w = rel_h->ne[1];
+        const int q_h = rel_h->ne[2];
 
-        ggml_tensor *rel_h_4d =
-            ggml_reshape_4d(ctx, rel_h, 1, q_size, attn->ne[1], attn->ne[2]);
+        GGML_ASSERT(q_w == rel_w->ne[1]);
+        GGML_ASSERT(q_h == rel_w->ne[2]);
+        GGML_ASSERT(attn->ne[0] == k_h*k_w);
+        GGML_ASSERT(attn->ne[1] == q_h*q_w);
+
+        ggml_tensor *attn_4d = ggml_reshape_4d(ctx, attn, k_w, k_h, attn->ne[1], attn->ne[2]);
+
+        ggml_tensor *rel_h_4d = ggml_reshape_4d(ctx, rel_h, 1, k_h, attn->ne[1], attn->ne[2]);
 
         ggml_tensor *rel_h_rep = ggml_repeat(ctx, rel_h_4d, attn_4d); // now same shape as attn_5d
 
-        ggml_tensor *rel_w_4d =
-            ggml_reshape_4d(ctx, rel_w, q_size, 1, attn->ne[1], attn->ne[2]);
+        ggml_tensor *rel_w_4d = ggml_reshape_4d(ctx, rel_w, k_w, 1, attn->ne[1], attn->ne[2]);
 
         ggml_tensor *rel_w_rep = ggml_repeat(ctx, rel_w_4d, attn_4d); // now same shape as attn_5d
 
-        ggml_tensor * result = ggml_add(ctx, attn_4d, ggml_add(ctx, rel_h_rep, rel_w_rep));
+        ggml_tensor * result = ggml_add_inplace(ctx, attn_4d, ggml_add_inplace(ctx, rel_h_rep, rel_w_rep));
         result = ggml_reshape_3d(ctx, result, attn->ne[0], attn->ne[1], attn->ne[2]);
 
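For reference, a scalar sketch of what the graph built by add_rel_pos_inplace computes for one attention head, assuming the shape comments above; the function name add_rel_pos_reference, the toy sizes, and the flattened row-major indexing are illustrative assumptions, not code from the patch. Each query position (qh, qw) gets rel_h broadcast across the key columns and rel_w broadcast across the key rows, which is what the size-1 axes in rel_h_4d / rel_w_4d, the two ggml_repeat calls, and the in-place adds express:

    #include <cstdio>
    #include <vector>

    // attn:  [q_h*q_w, k_h*k_w]   rel_h: [q_h, q_w, k_h]   rel_w: [q_h, q_w, k_w]
    // (all flattened row-major in those orders)
    static void add_rel_pos_reference(
            std::vector<float> & attn,
            const std::vector<float> & rel_h,
            const std::vector<float> & rel_w,
            int q_h, int q_w, int k_h, int k_w) {
        for (int qh = 0; qh < q_h; qh++) {
            for (int qw = 0; qw < q_w; qw++) {
                const int q = qh * q_w + qw;
                for (int kh = 0; kh < k_h; kh++) {
                    for (int kw = 0; kw < k_w; kw++) {
                        const int k = kh * k_w + kw;
                        // rel_h is broadcast over kw, rel_w over kh -- the size-1
                        // axes in rel_h_4d / rel_w_4d expanded by ggml_repeat
                        attn[q * (k_h * k_w) + k] +=
                            rel_h[q * k_h + kh] +
                            rel_w[q * k_w + kw];
                    }
                }
            }
        }
    }

    int main() {
        const int q_h = 2, q_w = 2, k_h = 2, k_w = 2; // toy sizes
        std::vector<float> attn (q_h * q_w * k_h * k_w, 0.0f);
        std::vector<float> rel_h(q_h * q_w * k_h, 1.0f);
        std::vector<float> rel_w(q_h * q_w * k_w, 2.0f);
        add_rel_pos_reference(attn, rel_h, rel_w, q_h, q_w, k_h, k_w);
        printf("attn[0] = %.1f\n", attn[0]); // 0 + 1 + 2 = 3.0
        return 0;
    }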