fix computation

This commit is contained in:
Aman Gupta 2026-02-11 11:34:47 +05:30
parent 4be60a28b6
commit ffe3e82c8b
2 changed files with 2 additions and 3 deletions

View File

@@ -10399,7 +10399,6 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
ggml_vec_scale_f32(S_v, k_local, 1.0f / fmaxf(norm, eps)); ggml_vec_scale_f32(S_v, k_local, 1.0f / fmaxf(norm, eps));
// state decay: S *= exp(g) // state decay: S *= exp(g)
// s_t is row-major, but scaling all elements is layout-agnostic
ggml_vec_scale_f32(S_v * S_v, s_t, g_val); ggml_vec_scale_f32(S_v * S_v, s_t, g_val);
// kv_mem[j] = sum_i S[j][i] * k[i] = dot(s_t[j*S_v:], k) // kv_mem[j] = sum_i S[j][i] * k[i] = dot(s_t[j*S_v:], k)

View File

@@ -6141,8 +6141,8 @@ struct ggml_tensor * ggml_gated_delta_net(
GGML_ASSERT(ggml_nelements(state) == S_v * S_v * H * n_seqs); GGML_ASSERT(ggml_nelements(state) == S_v * S_v * H * n_seqs);
// concat output and new_state into a single tensor // concat output and new_state into a single tensor
// output: S_v * H * n_tokens, state: S_v * S_v * H * n_seqs // output: S_v * H * n_tokens * n_seqs, state: S_v * S_v * H * n_seqs
const int64_t ne[4] = { S_v * H, n_tokens + S_v * n_seqs, 1, 1 }; const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + S_v * n_seqs, 1, 1 };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
result->op = GGML_OP_GATED_DELTA_NET; result->op = GGML_OP_GATED_DELTA_NET;