minor formatting and style

This commit is contained in:
Saba Fallah 2025-12-05 09:30:58 +01:00
parent 076138a428
commit f5bd310a5e
1 changed files with 6 additions and 13 deletions

View File

@ -2598,8 +2598,8 @@ private:
kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, 0.0f);
ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
cur = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3));
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, cur), cur->ne[0] * cur->ne[1], cur->ne[2] * cur->ne[3]);
}
@ -2801,7 +2801,8 @@ private:
rw = get_rel_pos(ctx0, layer.rel_pos_w, W, W); // [W, W, C]
rh = get_rel_pos(ctx0, layer.rel_pos_h, H, H); // [H, H, C]
qr = ggml_reshape_4d(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)), d_heads, W, H, B*n_heads);
qr = ggml_permute(ctx0, Q, 0, 2, 1, 3);
qr = ggml_reshape_4d(ctx0, ggml_cont(ctx0, qr), d_heads, W, H, B * n_heads);
const int WH_pad = GGML_PAD(W*H, GGML_KQ_MASK_PAD) - W*H;
@ -2818,16 +2819,8 @@ private:
float scale = 1.0f / sqrtf((float)d_heads);
cur = build_attn(
layer.o_w,
layer.o_b,
Q,
K,
V,
mask,
scale,
il
); // [B, H*W, n_embd]
cur = build_attn(layer.o_w, layer.o_b, Q, K, V, mask, scale,
il); // [B, H*W, n_embd]
cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur), n_embd, W, H, B);
}