mtmd: fix get_rel_pos

This commit is contained in:
bluebread 2025-11-21 17:12:12 +00:00
parent 5e6cf3c6a8
commit 7e9fbeccc5
1 changed files with 85 additions and 79 deletions

View File

@ -2467,101 +2467,107 @@ private:
}
// attn: [k_h*k_w, q_h*q_w]
// rel_h: [q_h, q_w, k_h]
// rel_w: [q_h, q_w, k_w]
// rel_h: [q_h, q_w, k_h]
// rel_w: [q_h, q_w, k_w]
static ggml_tensor * add_rel_pos_inplace(
ggml_context * ctx,
ggml_tensor * attn,
ggml_tensor * rel_w,
ggml_tensor * rel_h,
int q_size
) {
static ggml_tensor * add_rel_pos_inplace(
ggml_context * ctx,
ggml_tensor * attn,
ggml_tensor * rel_w,
ggml_tensor * rel_h,
int q_size
) {
ggml_tensor *attn_4d =
ggml_reshape_4d(ctx, attn, q_size,q_size, attn->ne[1], attn->ne[2]);
ggml_tensor *attn_4d =
ggml_reshape_4d(ctx, attn, q_size,q_size, attn->ne[1], attn->ne[2]);
ggml_tensor *rel_h_4d =
ggml_reshape_4d(ctx, rel_h, 1, q_size, attn->ne[1], attn->ne[2]);
ggml_tensor *rel_h_4d =
ggml_reshape_4d(ctx, rel_h, 1, q_size, attn->ne[1], attn->ne[2]);
ggml_tensor *rel_h_rep = ggml_repeat(ctx, rel_h_4d, attn_4d); // now same shape as attn_5d
ggml_tensor *rel_h_rep = ggml_repeat(ctx, rel_h_4d, attn_4d); // now same shape as attn_5d
ggml_tensor *rel_w_4d =
ggml_reshape_4d(ctx, rel_w, q_size, 1, attn->ne[1], attn->ne[2]);
ggml_tensor *rel_w_4d =
ggml_reshape_4d(ctx, rel_w, q_size, 1, attn->ne[1], attn->ne[2]);
ggml_tensor *rel_w_rep = ggml_repeat(ctx, rel_w_4d, attn_4d); // now same shape as attn_5d
ggml_tensor *rel_w_rep = ggml_repeat(ctx, rel_w_4d, attn_4d); // now same shape as attn_5d
ggml_tensor * result = ggml_add(ctx, attn_4d, ggml_add(ctx, rel_h_rep, rel_w_rep));
result = ggml_reshape_3d(ctx, result, attn->ne[0], attn->ne[1], attn->ne[2]);
ggml_tensor * result = ggml_add(ctx, attn_4d, ggml_add(ctx, rel_h_rep, rel_w_rep));
result = ggml_reshape_3d(ctx, result, attn->ne[0], attn->ne[1], attn->ne[2]);
return result;
}
return result;
}
static ggml_tensor * get_rel_pos(
ggml_context * ctx,
ggml_tensor * rel_pos, // [L, C]
int q_size,
int k_size
) {
static ggml_tensor * get_rel_pos(
ggml_context * ctx,
ggml_tensor * rel_pos, // [L, C]
int q_size,
int k_size
) {
const int64_t C = rel_pos->ne[0]; // channels
const int64_t L = rel_pos->ne[1]; // length
const auto dtype = rel_pos->type;
GGML_ASSERT(2*std::max(q_size, k_size) - 1 == L);
const int64_t L = rel_pos->ne[0]; // length
const int64_t C = rel_pos->ne[1]; // channels
// -------------------------------------------------
// 1) q_idx ← arange(0..q_size-1) [q_size]
// 2) k_idx ← arange(0..k_size-1) [k_size]
// -------------------------------------------------
// -------------------------------------------------
// 1) q_idx ← arange(0..q_size-1) [q_size]
// 2) k_idx ← arange(0..k_size-1) [k_size]
// -------------------------------------------------
// ggml_arange always returns FP32 tensor
ggml_tensor * q_coord = ggml_arange(ctx, 0.0f, static_cast<float>(q_size), 1.0f); // [q_size]
ggml_tensor * k_coord = ggml_arange(ctx, 0.0f, static_cast<float>(k_size), 1.0f); // [k_size]
ggml_tensor * rel = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, k_size, q_size);
// broadcast reshape:
q_coord = ggml_cont(ctx,
ggml_repeat(ctx,
ggml_reshape_2d(ctx, q_coord, 1, q_size), // [q_size, 1]
rel
)
); // [q_size, k_size]
k_coord = ggml_cont(ctx, ggml_repeat(ctx, k_coord, rel)); // [q_size, k_size]
// This wouldn't be triggered in DeepSeek-OCR. Just for compatibility with
// the original implementation.
if (q_size != k_size) {
q_coord = ggml_scale_inplace(ctx, q_coord, std::max((float)k_size/q_size, 1.0f));
k_coord = ggml_scale_inplace(ctx, k_coord, std::max((float)q_size/k_size, 1.0f));
}
// -------------------------------------------------
// relative_coords = q - k + (k_size - 1) // SAME as PyTorch when no scaling
// -------------------------------------------------
rel = ggml_sub(ctx, q_coord, k_coord); // [q_size, k_size]
rel = ggml_scale_bias(ctx, rel, 1.0f, static_cast<float>(k_size) - 1.0f); // [q_size, k_size]
// Clamp to [0, L-1] range for valid indexing
rel = ggml_clamp(ctx, rel, 0.0f, static_cast<float>(rel_pos->ne[1] - 1));
// -------------------------------------------------
// clamp to [0, L-1] and cast to int32 (for ggml_get_rows)
// -------------------------------------------------
ggml_tensor * idx_2d = ggml_cast(ctx, rel, GGML_TYPE_I32); // [q_size, k_size]
// Gather from rel_pos → [qk, C]
// -------------------------------------------------
// flatten to 1D for ggml_get_rows
int qk = q_size * k_size;
ggml_tensor * idx_flat = ggml_reshape_1d(ctx, idx_2d, qk); // [qk]
ggml_tensor * gathered = ggml_get_rows(ctx, rel_pos, idx_flat); // [qk, C]
// -------------------------------------------------
// Gather from rel_pos → [qk, C]
// -------------------------------------------------
ggml_tensor * out = ggml_reshape_3d(ctx, gathered, C, k_size, q_size); // [qk, C]
ggml_tensor * q_coord = ggml_cast(ctx,
ggml_arange(ctx, 0.0f, static_cast<float>(q_size), 1.0f),
GGML_TYPE_F32); // [q_size]
ggml_tensor * k_coord = ggml_cast(ctx,
ggml_arange(ctx, 0.0f, static_cast<float>(k_size), 1.0f),
GGML_TYPE_F32); // [k_size]
ggml_tensor * rel = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, q_size, k_size);
q_coord = ggml_cont(ctx,ggml_repeat(ctx, q_coord, rel)); // [q_size, k_size]
// broadcast reshape:
k_coord = ggml_reshape_2d(ctx, k_coord, 1, k_size); // [1, k_size]
k_coord = ggml_cont(ctx,ggml_repeat(ctx, k_coord, rel)); // [q_size, k_size]
// -------------------------------------------------
// relative_coords = q - k + (k_size - 1) // SAME as PyTorch when no scaling
// -------------------------------------------------
rel = ggml_sub(ctx, k_coord, q_coord); // [q_size, k_size]
rel = ggml_scale_bias(ctx, rel, 1.0f, static_cast<float>(k_size) - 1.0f); // [q_size, k_size]
// -------------------------------------------------
// clamp to [0, L-1] and cast to int32 (for ggml_get_rows)
// -------------------------------------------------
ggml_tensor * rel_clamped = ggml_clamp(ctx, rel, 0, static_cast<float>(L - 1));
ggml_tensor * idx_2d = ggml_cast(ctx, rel_clamped, GGML_TYPE_I32); // [q_size, k_size]
// flatten to 1D for ggml_get_rows
const int64_t qk = static_cast<int64_t>(q_size) * static_cast<int64_t>(k_size);
ggml_tensor * idx_flat = ggml_reshape_1d(ctx, idx_2d, qk); // [qk]
// -------------------------------------------------
// Gather from rel_pos → [qk, C]
// -------------------------------------------------
ggml_tensor * gathered = ggml_get_rows(ctx, rel_pos, idx_flat); // [qk, C]
// reshape to final output → [q_size, k_size, C]
ggml_tensor * out = ggml_reshape_3d(ctx, gathered,rel_pos->ne[0],
q_size,
k_size);
return out; // [q_size, k_size, C]
}
return out; // [q_size, k_size, C]
}
// Implementation based on approach suggested by Acly
// See: https://github.com/ggml-org/llama.cpp/pull/17383#issuecomment-3554227091