From 8666c546f953c73d71c8736a03b334554e22326c Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Wed, 11 Feb 2026 10:52:49 +0530 Subject: [PATCH 1/9] ggml: add GATED_DELTA_NET op --- ggml/include/ggml.h | 10 ++ ggml/src/ggml-cpu/ggml-cpu.c | 10 ++ ggml/src/ggml-cpu/ops.cpp | 180 +++++++++++++++++++++++++++++++++++ ggml/src/ggml-cpu/ops.h | 1 + ggml/src/ggml.c | 59 +++++++++++- 5 files changed, 258 insertions(+), 2 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index f759e2d588..82b27d2343 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -556,6 +556,7 @@ extern "C" { GGML_OP_GATED_LINEAR_ATTN, GGML_OP_RWKV_WKV7, GGML_OP_SOLVE_TRI, + GGML_OP_GATED_DELTA_NET, GGML_OP_UNARY, @@ -2466,6 +2467,15 @@ extern "C" { bool lower, bool uni); + GGML_API struct ggml_tensor * ggml_gated_delta_net( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * g, + struct ggml_tensor * beta, + struct ggml_tensor * state); + // custom operators typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata); diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index b003fe13fd..f11d7eab8f 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2021,6 +2021,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_solve_tri(params, tensor); } break; + case GGML_OP_GATED_DELTA_NET: + { + ggml_compute_forward_gated_delta_net(params, tensor); + } break; case GGML_OP_MAP_CUSTOM1: { ggml_compute_forward_map_custom1(params, tensor); @@ -2200,6 +2204,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { } break; case GGML_OP_COUNT_EQUAL: case GGML_OP_SOLVE_TRI: + case GGML_OP_GATED_DELTA_NET: { n_tasks = n_threads; } break; @@ -2905,6 +2910,11 @@ struct ggml_cplan ggml_graph_plan( { cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks); } break; + case GGML_OP_GATED_DELTA_NET: + { + const int64_t S_v = node->src[0]->ne[0]; + cur = 4 * S_v * sizeof(float) * n_tasks; + } break; case GGML_OP_COUNT: { GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index ed45350207..0069d8c0f5 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -10296,6 +10296,186 @@ void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, s } } +// ggml_compute_forward_gated_delta_net +static void ggml_compute_forward_gated_delta_net_one_chunk( + const ggml_compute_params * params, + ggml_tensor * dst, + int64_t ir0, + int64_t ir1) { + + ggml_tensor * src_q = dst->src[0]; + ggml_tensor * src_k = dst->src[1]; + ggml_tensor * src_v = dst->src[2]; + ggml_tensor * src_g = dst->src[3]; + ggml_tensor * src_beta = dst->src[4]; + ggml_tensor * src_state = dst->src[5]; + + const int64_t S_v = src_q->ne[0]; + const int64_t H = src_q->ne[1]; + const int64_t n_tokens = src_q->ne[2]; + const int64_t n_seqs = src_q->ne[3]; + + GGML_ASSERT(ggml_is_contiguous(src_q)); + GGML_ASSERT(ggml_is_contiguous(src_k)); + GGML_ASSERT(ggml_is_contiguous(src_v)); + GGML_ASSERT(ggml_is_contiguous(src_g)); + GGML_ASSERT(ggml_is_contiguous(src_beta)); + GGML_ASSERT(ggml_is_contiguous(src_state)); + + // scratch layout per thread: [q_local(S_v) | k_local(S_v) | kv_mem(S_v) | delta(S_v)] + const int64_t scratch_per_thread = 4 * S_v; + const int ith = params->ith; + float * scratch = (float 
*)params->wdata + ith * scratch_per_thread + CACHE_LINE_SIZE_F32; + + float * q_local = scratch; + float * k_local = scratch + S_v; + float * kv_mem = scratch + 2 * S_v; + float * delta = scratch + 3 * S_v; + + // output layout: [attn_scores | new_states] + // attn_scores: S_v * H * n_tokens * n_seqs floats + // new_states: S_v * S_v * H * n_seqs floats + const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; + float * attn_out_base = (float *)dst->data; + float * state_out_base = (float *)dst->data + attn_score_elems; + + const float * state_in_base = (const float *)src_state->data; + const float * q_base = (const float *)src_q->data; + const float * k_base = (const float *)src_k->data; + const float * v_base = (const float *)src_v->data; + const float * g_base = (const float *)src_g->data; + const float * beta_base = (const float *)src_beta->data; + + const float eps = ggml_get_op_params_f32(dst, 0); + + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t h_idx = ir % H; + const int64_t sequence = ir / H; + + // output state pointer for this (head, seq) + float * s_out = state_out_base + (sequence * H + h_idx) * S_v * S_v; + + // copy input state for this (head, seq) into output + const float * s_in = state_in_base + (sequence * H + h_idx) * S_v * S_v; + memcpy(s_out, s_in, S_v * S_v * sizeof(float)); + + // attn output pointer for first token of this (head, seq) + float * attn_data = attn_out_base + (sequence * n_tokens * H + h_idx) * S_v; + + for (int64_t t = 0; t < n_tokens; t++) { + // input pointers for this (head, seq, token) + // layout is contiguous [S_v, H, n_tokens, n_seqs] + const int64_t qkv_offset = sequence * n_tokens * H * S_v + t * H * S_v + h_idx * S_v; + const float * q_d = q_base + qkv_offset; + const float * k_d = k_base + qkv_offset; + const float * v_d = v_base + qkv_offset; + + // g and beta layout: [H, n_tokens, n_seqs] + const int64_t gb_offset = sequence * n_tokens * H + t * H + h_idx; + const float beta_val_raw = beta_base[gb_offset]; + const float beta_val = 1.0f / (1.0f + expf(-beta_val_raw)); // sigmoid + const float g_val = expf(g_base[gb_offset]); + + memcpy(q_local, q_d, S_v * sizeof(float)); + memcpy(k_local, k_d, S_v * sizeof(float)); + + // l2-norm q and scale by 1/sqrt(S_v) + float norm; + ggml_vec_norm_f32(S_v, &norm, q_local); + ggml_vec_scale_f32(S_v, q_local, 1.0f / fmaxf(norm, eps)); + ggml_vec_scale_f32(S_v, q_local, 1.0f / sqrtf((float)S_v)); + + // l2-norm k + ggml_vec_norm_f32(S_v, &norm, k_local); + ggml_vec_scale_f32(S_v, k_local, 1.0f / fmaxf(norm, eps)); + + // state decay: S *= exp(g) + ggml_vec_scale_f32(S_v * S_v, s_out, g_val); + + // kv_mem = S @ k + for (int64_t i = 0; i < S_v; ++i) { + ggml_vec_dot_f32(S_v, &kv_mem[i], 0, &s_out[i * S_v], 0, k_local, 0, 1); + } + + // delta = (v - kv_mem) * beta + for (int64_t i = 0; i < S_v; ++i) { + delta[i] = (v_d[i] - kv_mem[i]) * beta_val; + } + + // outer product update: S += k (x) delta + for (int64_t i = 0; i < S_v; ++i) { + ggml_vec_mad_f32(S_v, &s_out[i * S_v], delta, k_local[i]); + } + + // attn output = S @ q + for (int64_t i = 0; i < S_v; ++i) { + ggml_vec_dot_f32(S_v, &attn_data[i], 0, &s_out[i * S_v], 0, q_local, 0, 1); + } + + attn_data += S_v * H; // advance to next token + } + } +} + + +static void ggml_compute_forward_gated_delta_net_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + ggml_tensor * Q = dst->src[0]; + int64_t nr = Q->ne[1] * Q->ne[3]; + + // disable for NUMA + const bool disable_chunking = ggml_is_numa(); + + int nth = params->nth; + 
int ith = params->ith; + + // 4x chunks per thread + int nth_scaled = nth * 4; + int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled; + int64_t nchunk = (nr + chunk_size - 1) / chunk_size; + + if (nth == 1 || nchunk < nth || disable_chunking) { + nchunk = nth; + } + + if (ith == 0) { + ggml_threadpool_chunk_set(params->threadpool, nth); + } + + ggml_barrier(params->threadpool); + + const int64_t dr = (nr + nchunk - 1) / nchunk; + + int current_chunk = ith; + + while (current_chunk < nchunk) { + const int64_t ir0 = dr * current_chunk; + const int64_t ir1 = MIN(ir0 + dr, nr); + + ggml_compute_forward_gated_delta_net_one_chunk(params, dst, ir0, ir1); + current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1); + } +} + +void ggml_compute_forward_gated_delta_net( + const ggml_compute_params * params, + ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_gated_delta_net_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_rwkv_wkv7 static void ggml_compute_forward_rwkv_wkv7_f32( diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index 0fdfee7976..3fa1443abc 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -102,6 +102,7 @@ void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, s void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_gated_delta_net(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_map_custom3(const struct ggml_compute_params * params, struct ggml_tensor * dst); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 500cb6b72f..86b7e8741a 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1030,6 +1030,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GATED_LINEAR_ATTN", "RWKV_WKV7", "SOLVE_TRI", + "GATED_DELTA_NET", "UNARY", @@ -1047,7 +1048,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GLU", }; -static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95"); +static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1139,6 +1140,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "gated_linear_attn(k, v, q, gate, s)", "rwkv_wkv7(r, w, k, v, a, b, s)", "A X = B, A triangular, solve X", + "gated_delta_net(q, k, v, g, beta, s)", "unary(x)", @@ -1156,7 +1158,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "glu(x)", }; -static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95"); +static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -6101,6 +6103,59 @@ struct ggml_tensor * ggml_solve_tri( return result; } +// ggml_gated_delta_net + +struct ggml_tensor * ggml_gated_delta_net( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * g, + struct ggml_tensor * beta, + struct ggml_tensor * state) { + 
GGML_ASSERT(ggml_is_contiguous(q));
+ GGML_ASSERT(ggml_is_contiguous(k));
+ GGML_ASSERT(ggml_is_contiguous(v));
+ GGML_ASSERT(ggml_is_contiguous(g));
+ GGML_ASSERT(ggml_is_contiguous(beta));
+ GGML_ASSERT(ggml_is_contiguous(state));
+
+ GGML_ASSERT(q->type == GGML_TYPE_F32);
+ GGML_ASSERT(k->type == GGML_TYPE_F32);
+ GGML_ASSERT(v->type == GGML_TYPE_F32);
+ GGML_ASSERT(g->type == GGML_TYPE_F32);
+ GGML_ASSERT(beta->type == GGML_TYPE_F32);
+ GGML_ASSERT(state->type == GGML_TYPE_F32);
+
+ const int64_t S_k = q->ne[0];
+ const int64_t H = q->ne[1];
+ const int64_t n_tokens = q->ne[2];
+ const int64_t n_seqs = q->ne[3];
+
+ const int64_t S_v = v->ne[0];
+
+ GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
+ GGML_ASSERT(v->ne[1] == H && v->ne[2] == n_tokens && v->ne[3] == n_seqs);
+ GGML_ASSERT(g->ne[0] == H && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
+ GGML_ASSERT(beta->ne[0] == H && beta->ne[1] == n_tokens && beta->ne[2] == n_seqs);
+ GGML_ASSERT(ggml_nelements(state) == S_v * S_v * H * n_seqs);
+
+ // concat output and new_state into a single tensor
+ // output: S_v * H * n_tokens, state: S_v * S_v * H * n_seqs
+ const int64_t ne[4] = { S_v * H, n_tokens + S_v * n_seqs, 1, 1 };
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+ result->op = GGML_OP_GATED_DELTA_NET;
+ result->src[0] = q;
+ result->src[1] = k;
+ result->src[2] = v;
+ result->src[3] = g;
+ result->src[4] = beta;
+ result->src[5] = state;
+
+ return result;
+}
+
 ////////////////////////////////////////////////////////////////////////////////

 struct ggml_hash_set ggml_hash_set_new(size_t size) {

From 4be60a28b6c4f715c72f331bc7cdac5f88cf39ad Mon Sep 17 00:00:00 2001
From: Aman Gupta
Date: Wed, 11 Feb 2026 11:11:44 +0530
Subject: [PATCH 2/9] transpose

---
 ggml/src/ggml-cpu/ggml-cpu.c | 2 +-
 ggml/src/ggml-cpu/ops.cpp | 63 ++++++++++++++++++++++++------------
 2 files changed, 43 insertions(+), 22 deletions(-)

diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index f11d7eab8f..03bac78b8d 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -2913,7 +2913,7 @@ struct ggml_cplan ggml_graph_plan(
 case GGML_OP_GATED_DELTA_NET:
 {
 const int64_t S_v = node->src[0]->ne[0];
- cur = 4 * S_v * sizeof(float) * n_tasks;
+ cur = (S_v * S_v + 4 * S_v) * sizeof(float) * n_tasks;
 } break;
 case GGML_OP_COUNT:
 {
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 0069d8c0f5..c249e54772 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -10322,15 +10322,17 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
 GGML_ASSERT(ggml_is_contiguous(src_beta));
 GGML_ASSERT(ggml_is_contiguous(src_state));

- // scratch layout per thread: [q_local(S_v) | k_local(S_v) | kv_mem(S_v) | delta(S_v)]
- const int64_t scratch_per_thread = 4 * S_v;
+ // scratch layout per thread: [s_t(S_v*S_v) | q_local(S_v) | k_local(S_v) | kv_mem(S_v) | delta(S_v)]
+ // s_t holds the transposed (row-major) state for contiguous vector ops
+ const int64_t scratch_per_thread = S_v * S_v + 4 * S_v;
 const int ith = params->ith;
 float * scratch = (float *)params->wdata + ith * scratch_per_thread + CACHE_LINE_SIZE_F32;

- float * q_local = scratch;
- float * k_local = scratch + S_v;
- float * kv_mem = scratch + 2 * S_v;
- float * delta = scratch + 3 * S_v;
+ float * s_t = scratch;
+ float * q_local = scratch + S_v * S_v;
+ float * k_local = scratch + S_v * S_v + S_v;
+ float * kv_mem = scratch + S_v * S_v + 2 * S_v;
+ float
* delta = scratch + S_v * S_v + 3 * S_v; // output layout: [attn_scores | new_states] // attn_scores: S_v * H * n_tokens * n_seqs floats @@ -10352,12 +10354,19 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( const int64_t h_idx = ir % H; const int64_t sequence = ir / H; - // output state pointer for this (head, seq) + // output state pointer for this (head, seq) — column-major (ggml layout) float * s_out = state_out_base + (sequence * H + h_idx) * S_v * S_v; - // copy input state for this (head, seq) into output + // Copy state into scratch in row-major layout of S (not S^T) + // ggml column-major: s_in[j + i*S_v] = S[j][i] (j=dim0, i=dim1) + // row-major of S: s_t[j * S_v + i] = S[j][i] (row j is contiguous over i) + // This makes kv_mem[j] = dot(s_t[j*S_v:], k) a contiguous dot product const float * s_in = state_in_base + (sequence * H + h_idx) * S_v * S_v; - memcpy(s_out, s_in, S_v * S_v * sizeof(float)); + for (int64_t j = 0; j < S_v; ++j) { + for (int64_t i = 0; i < S_v; ++i) { + s_t[j * S_v + i] = s_in[j + i * S_v]; + } + } // attn output pointer for first token of this (head, seq) float * attn_data = attn_out_base + (sequence * n_tokens * H + h_idx) * S_v; @@ -10390,30 +10399,42 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( ggml_vec_scale_f32(S_v, k_local, 1.0f / fmaxf(norm, eps)); // state decay: S *= exp(g) - ggml_vec_scale_f32(S_v * S_v, s_out, g_val); + // s_t is row-major, but scaling all elements is layout-agnostic + ggml_vec_scale_f32(S_v * S_v, s_t, g_val); - // kv_mem = S @ k - for (int64_t i = 0; i < S_v; ++i) { - ggml_vec_dot_f32(S_v, &kv_mem[i], 0, &s_out[i * S_v], 0, k_local, 0, 1); + // kv_mem[j] = sum_i S[j][i] * k[i] = dot(s_t[j*S_v:], k) + // row j of s_t is contiguous -> use ggml_vec_dot_f32 + for (int64_t j = 0; j < S_v; ++j) { + ggml_vec_dot_f32(S_v, &kv_mem[j], 0, &s_t[j * S_v], 0, k_local, 0, 1); } // delta = (v - kv_mem) * beta - for (int64_t i = 0; i < S_v; ++i) { - delta[i] = (v_d[i] - kv_mem[i]) * beta_val; + for (int64_t j = 0; j < S_v; ++j) { + delta[j] = (v_d[j] - kv_mem[j]) * beta_val; } - // outer product update: S += k (x) delta - for (int64_t i = 0; i < S_v; ++i) { - ggml_vec_mad_f32(S_v, &s_out[i * S_v], delta, k_local[i]); + // outer product: S[j][i] += k[i] * delta[j] + // s_t[j * S_v + i] += k[i] * delta[j] + // row j gets k[:] scaled by delta[j] -> contiguous ggml_vec_mad_f32 + for (int64_t j = 0; j < S_v; ++j) { + ggml_vec_mad_f32(S_v, &s_t[j * S_v], k_local, delta[j]); } - // attn output = S @ q - for (int64_t i = 0; i < S_v; ++i) { - ggml_vec_dot_f32(S_v, &attn_data[i], 0, &s_out[i * S_v], 0, q_local, 0, 1); + // attn_out[j] = sum_i S[j][i] * q[i] = dot(s_t[j*S_v:], q) + for (int64_t j = 0; j < S_v; ++j) { + ggml_vec_dot_f32(S_v, &attn_data[j], 0, &s_t[j * S_v], 0, q_local, 0, 1); } attn_data += S_v * H; // advance to next token } + + // copy scratch back to output: row-major of S -> column-major (ggml layout) + // s_t[j * S_v + i] = S[j][i] -> s_out[j + i * S_v] = S[j][i] + for (int64_t j = 0; j < S_v; ++j) { + for (int64_t i = 0; i < S_v; ++i) { + s_out[j + i * S_v] = s_t[j * S_v + i]; + } + } } } From ffe3e82c8b51c3207dc6094c35c95a160987b4ce Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Wed, 11 Feb 2026 11:34:47 +0530 Subject: [PATCH 3/9] fix computation --- ggml/src/ggml-cpu/ops.cpp | 1 - ggml/src/ggml.c | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index c249e54772..74b3368480 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ 
b/ggml/src/ggml-cpu/ops.cpp
@@ -10399,7 +10399,6 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
 ggml_vec_scale_f32(S_v, k_local, 1.0f / fmaxf(norm, eps));

 // state decay: S *= exp(g)
- // s_t is row-major, but scaling all elements is layout-agnostic
 ggml_vec_scale_f32(S_v * S_v, s_t, g_val);

 // kv_mem[j] = sum_i S[j][i] * k[i] = dot(s_t[j*S_v:], k)
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 86b7e8741a..e7865b683a 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -6141,8 +6141,8 @@ struct ggml_tensor * ggml_gated_delta_net(
 GGML_ASSERT(ggml_nelements(state) == S_v * S_v * H * n_seqs);

 // concat output and new_state into a single tensor
- // output: S_v * H * n_tokens, state: S_v * S_v * H * n_seqs
- const int64_t ne[4] = { S_v * H, n_tokens + S_v * n_seqs, 1, 1 };
+ // output: S_v * H * n_tokens * n_seqs, state: S_v * S_v * H * n_seqs
+ const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + S_v * n_seqs, 1, 1 };
 struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);

 result->op = GGML_OP_GATED_DELTA_NET;

From c7edcf22ec22346e5b08f695a9776b65e1e4ff2f Mon Sep 17 00:00:00 2001
From: Aman Gupta
Date: Wed, 11 Feb 2026 11:46:27 +0530
Subject: [PATCH 4/9] add eps

---
 ggml/include/ggml.h | 3 ++-
 ggml/src/ggml.c | 5 ++++-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 82b27d2343..8edd675d41 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -2474,7 +2474,8 @@ extern "C" {
 struct ggml_tensor * v,
 struct ggml_tensor * g,
 struct ggml_tensor * beta,
- struct ggml_tensor * state);
+ struct ggml_tensor * state,
+ float eps);

 // custom operators

diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index e7865b683a..e25287b8a3 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -6112,7 +6112,8 @@ struct ggml_tensor * ggml_gated_delta_net(
 struct ggml_tensor * v,
 struct ggml_tensor * g,
 struct ggml_tensor * beta,
- struct ggml_tensor * state) {
+ struct ggml_tensor * state,
+ float eps) {
 GGML_ASSERT(ggml_is_contiguous(q));
 GGML_ASSERT(ggml_is_contiguous(k));
 GGML_ASSERT(ggml_is_contiguous(v));
 GGML_ASSERT(ggml_is_contiguous(g));
 GGML_ASSERT(ggml_is_contiguous(beta));
 GGML_ASSERT(ggml_is_contiguous(state));
@@ -6145,6 +6146,8 @@ struct ggml_tensor * ggml_gated_delta_net(
 const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + S_v * n_seqs, 1, 1 };
 struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);

+ ggml_set_op_params_f32(result, 0, eps);
+
 result->op = GGML_OP_GATED_DELTA_NET;

From 86833eb747d5d5b8be216f7a9ecd365eaecb4cfa Mon Sep 17 00:00:00 2001
From: Aman Gupta
Date: Wed, 11 Feb 2026 11:47:54 +0530
Subject: [PATCH 5/9] replace qwen3next

---
 src/models/qwen3next.cpp | 28 ++++++++++++++++++++++------
 1 file changed, 22 insertions(+), 6 deletions(-)

diff --git a/src/models/qwen3next.cpp b/src/models/qwen3next.cpp
index 99b1a76a48..886eb3d66f 100644
--- a/src/models/qwen3next.cpp
+++ b/src/models/qwen3next.cpp
@@ -781,15 +781,31 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
 cb(k_conv, "k_conv_predelta", il);
 cb(v_conv, "v_conv_predelta", il);

- // Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens
- std::pair<ggml_tensor *, ggml_tensor *> attn_out; // pair of (output, new_state)
+ // Choose between build_delta_net_chunking and fused ggml_gated_delta_net based on n_tokens
+ ggml_tensor * output;
+ ggml_tensor * new_state;
 if (n_seq_tokens == 1) {
- attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
+ // Fused op expects state as [S_v*S_v*H, n_seqs]
+
ggml_tensor * state_2d = ggml_reshape_2d(ctx0, state, head_v_dim * head_v_dim * num_v_heads, n_seqs);
+ ggml_tensor * result = ggml_gated_delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state_2d,
+ hparams.f_norm_rms_eps);
+
+ // Unpack: attn scores then new state
+ const int64_t attn_elems = head_v_dim * num_v_heads * n_seq_tokens * n_seqs;
+ const int64_t state_elems = head_v_dim * head_v_dim * num_v_heads * n_seqs;
+
+ output = ggml_view_4d(ctx0, result, head_v_dim, num_v_heads, n_seq_tokens, n_seqs,
+ head_v_dim * sizeof(float),
+ head_v_dim * num_v_heads * sizeof(float),
+ head_v_dim * num_v_heads * n_seq_tokens * sizeof(float),
+ 0);
+ new_state = ggml_view_1d(ctx0, result, state_elems, attn_elems * sizeof(float));
 } else {
- attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
+ std::pair<ggml_tensor *, ggml_tensor *> attn_out;
+ attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
+ output = attn_out.first;
+ new_state = attn_out.second;
 }
- ggml_tensor * output = attn_out.first;
- ggml_tensor * new_state = attn_out.second;
 cb(output, "attn_output", il);
 cb(new_state, "new_state", il);

From 54ea12238539ab378e4969d5007ff5813eeabd24 Mon Sep 17 00:00:00 2001
From: Aman Gupta
Date: Wed, 11 Feb 2026 16:53:36 +0100
Subject: [PATCH 6/9] simplify impl + add cuda code

---
 ggml/src/ggml-cpu/ggml-cpu.c | 2 +-
 ggml/src/ggml-cpu/ops.cpp | 52 ++---------
 ggml/src/ggml-cuda/gated_delta_net.cu | 124 +++++++++++++++++++++++++
 ggml/src/ggml-cuda/gated_delta_net.cuh | 4 +
 ggml/src/ggml-cuda/ggml-cuda.cu | 5 +
 src/models/qwen3next.cpp | 28 ++----
 tests/test-backend-ops.cpp | 35 +++++++
 7 files changed, 185 insertions(+), 65 deletions(-)
 create mode 100644 ggml/src/ggml-cuda/gated_delta_net.cu
 create mode 100644 ggml/src/ggml-cuda/gated_delta_net.cuh

diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index 03bac78b8d..ac12e9c99f 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -2913,7 +2913,7 @@ struct ggml_cplan ggml_graph_plan(
 case GGML_OP_GATED_DELTA_NET:
 {
 const int64_t S_v = node->src[0]->ne[0];
- cur = (S_v * S_v + 4 * S_v) * sizeof(float) * n_tasks;
+ cur = (S_v * S_v + S_v) * sizeof(float) * n_tasks;
 } break;
 case GGML_OP_COUNT:
 {
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 74b3368480..fcc6dde15b 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -10322,17 +10322,14 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
 GGML_ASSERT(ggml_is_contiguous(src_beta));
 GGML_ASSERT(ggml_is_contiguous(src_state));

- // scratch layout per thread: [s_t(S_v*S_v) | q_local(S_v) | k_local(S_v) | kv_mem(S_v) | delta(S_v)]
+ // scratch layout per thread: [s_t(S_v*S_v) | delta(S_v)]
 // s_t holds the transposed (row-major) state for contiguous vector ops
- const int64_t scratch_per_thread = S_v * S_v + 4 * S_v;
+ const int64_t scratch_per_thread = S_v * S_v + S_v;
 const int ith = params->ith;
 float * scratch = (float *)params->wdata + ith * scratch_per_thread + CACHE_LINE_SIZE_F32;

 float * s_t = scratch;
- float * q_local = scratch + S_v * S_v;
- float * k_local = scratch + S_v * S_v + S_v;
- float * kv_mem = scratch + S_v * S_v + 2 * S_v;
- float * delta = scratch + S_v * S_v + 3 * S_v;
+ float * delta = scratch + S_v * S_v;

 // output layout: [attn_scores | new_states]
 // attn_scores: S_v * H * n_tokens * n_seqs floats
@@ -10348,19 +10345,13 @@ static void
ggml_compute_forward_gated_delta_net_one_chunk(
 const float * g_base = (const float *)src_g->data;
 const float * beta_base = (const float *)src_beta->data;

- const float eps = ggml_get_op_params_f32(dst, 0);
-
 for (int64_t ir = ir0; ir < ir1; ++ir) {
 const int64_t h_idx = ir % H;
 const int64_t sequence = ir / H;

- // output state pointer for this (head, seq) — column-major (ggml layout)
 float * s_out = state_out_base + (sequence * H + h_idx) * S_v * S_v;

- // Copy state into scratch in row-major layout of S (not S^T)
- // ggml column-major: s_in[j + i*S_v] = S[j][i] (j=dim0, i=dim1)
- // row-major of S: s_t[j * S_v + i] = S[j][i] (row j is contiguous over i)
- // This makes kv_mem[j] = dot(s_t[j*S_v:], k) a contiguous dot product
+ // transpose
 const float * s_in = state_in_base + (sequence * H + h_idx) * S_v * S_v;
 for (int64_t j = 0; j < S_v; ++j) {
 for (int64_t i = 0; i < S_v; ++i) {
 s_t[j * S_v + i] = s_in[j + i * S_v];
@@ -10379,56 +10370,33 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
 const float * k_d = k_base + qkv_offset;
 const float * v_d = v_base + qkv_offset;

- // g and beta layout: [H, n_tokens, n_seqs]
 const int64_t gb_offset = sequence * n_tokens * H + t * H + h_idx;
 const float beta_val_raw = beta_base[gb_offset];
 const float beta_val = 1.0f / (1.0f + expf(-beta_val_raw)); // sigmoid
 const float g_val = expf(g_base[gb_offset]);

- memcpy(q_local, q_d, S_v * sizeof(float));
- memcpy(k_local, k_d, S_v * sizeof(float));
-
- // l2-norm q and scale by 1/sqrt(S_v)
- float norm;
- ggml_vec_norm_f32(S_v, &norm, q_local);
- ggml_vec_scale_f32(S_v, q_local, 1.0f / fmaxf(norm, eps));
- ggml_vec_scale_f32(S_v, q_local, 1.0f / sqrtf((float)S_v));
-
- // l2-norm k
- ggml_vec_norm_f32(S_v, &norm, k_local);
- ggml_vec_scale_f32(S_v, k_local, 1.0f / fmaxf(norm, eps));
-
- // state decay: S *= exp(g)
 ggml_vec_scale_f32(S_v * S_v, s_t, g_val);

- // kv_mem[j] = sum_i S[j][i] * k[i] = dot(s_t[j*S_v:], k)
- // row j of s_t is contiguous -> use ggml_vec_dot_f32
 for (int64_t j = 0; j < S_v; ++j) {
- ggml_vec_dot_f32(S_v, &kv_mem[j], 0, &s_t[j * S_v], 0, k_local, 0, 1);
- }
-
- // delta = (v - kv_mem) * beta
- for (int64_t j = 0; j < S_v; ++j) {
- delta[j] = (v_d[j] - kv_mem[j]) * beta_val;
+ float kv_j;
+ ggml_vec_dot_f32(S_v, &kv_j, 0, &s_t[j * S_v], 0, k_d, 0, 1);
+ delta[j] = (v_d[j] - kv_j) * beta_val;
 }

 // outer product: S[j][i] += k[i] * delta[j]
 for (int64_t j = 0; j < S_v; ++j) {
- ggml_vec_mad_f32(S_v, &s_t[j * S_v], k_local, delta[j]);
+ ggml_vec_mad_f32(S_v, &s_t[j * S_v], k_d, delta[j]);
 }

 // attn_out[j] = sum_i S[j][i] * q[i] = dot(s_t[j*S_v:], q)
 for (int64_t j = 0; j < S_v; ++j) {
- ggml_vec_dot_f32(S_v, &attn_data[j], 0, &s_t[j * S_v], 0, q_local, 0, 1);
+ ggml_vec_dot_f32(S_v, &attn_data[j], 0, &s_t[j * S_v], 0, q_d, 0, 1);
 }

 attn_data += S_v * H; // advance to next token
 }

- // copy scratch back to output: row-major of S -> column-major (ggml layout)
- // s_t[j * S_v + i] = S[j][i] -> s_out[j + i * S_v] = S[j][i]
+ // transpose back
 for (int64_t j = 0; j < S_v; ++j) {
 for (int64_t i = 0; i < S_v; ++i) {
 s_out[j + i * S_v] = s_t[j * S_v + i];
diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu
new file mode 100644
index 0000000000..88dc6c65d7
--- /dev/null
+++ b/ggml/src/ggml-cuda/gated_delta_net.cu
@@ -0,0 +1,124 @@
+#include "ggml-cuda/common.cuh"
+#include "gated_delta_net.cuh"
+
+template <int S_v>
+__global__ void gated_delta_net_cuda(
+ const float * q,
+
const float * k,
+ const float * v,
+ const float * g,
+ const float * beta,
+ const float * curr_state,
+ float * dst,
+ int64_t H,
+ int64_t n_tokens,
+ int64_t n_seqs
+) {
+ const int64_t h_idx = blockIdx.x;
+ const int64_t sequence = blockIdx.y;
+ const int col = threadIdx.x; // each thread owns one column
+
+ const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
+ float * attn_data = dst;
+ float * state = dst + attn_score_elems;
+
+ const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v;
+ state += state_offset;
+ curr_state += state_offset;
+ attn_data += (sequence * n_tokens * H + h_idx) * S_v;
+
+ // Copy input state to output state (working area)
+#pragma unroll
+ for (int i = 0; i < S_v; i++) {
+ state[i * S_v + col] = curr_state[i * S_v + col];
+ }
+
+ for (int t = 0; t < n_tokens; t++) {
+ const int64_t qkv_offset = sequence * n_tokens * H * S_v + t * H * S_v + h_idx * S_v;
+ const float * q_t = q + qkv_offset;
+ const float * k_t = k + qkv_offset;
+ const float * v_t = v + qkv_offset;
+
+ const int64_t gb_offset = sequence * n_tokens * H + t * H + h_idx;
+ const float beta_val = 1.0f / (1.0f + expf(-beta[gb_offset]));
+ const float g_val = expf(g[gb_offset]);
+
+ // kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i]
+ float kv_col = 0.0f;
+#pragma unroll
+ for (int i = 0; i < S_v; i++) {
+ kv_col += state[i * S_v + col] * k_t[i];
+ }
+
+ // delta[col] = (v[col] - g * kv[col]) * beta
+ float delta_col = (v_t[col] - g_val * kv_col) * beta_val;
+
+ // fused: S[i][col] = g * S[i][col] + k[i] * delta[col]
+#pragma unroll
+ for (int i = 0; i < S_v; i++) {
+ state[i * S_v + col] = g_val * state[i * S_v + col] + k_t[i] * delta_col;
+ }
+
+ // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
+ float attn_col = 0.0f;
+#pragma unroll
+ for (int i = 0; i < S_v; i++) {
+ attn_col += state[i * S_v + col] * q_t[i];
+ }
+ attn_data[col] = attn_col;
+ attn_data += S_v * H;
+ }
+}
+
+void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx,
+ ggml_tensor * dst) {
+
+ ggml_tensor * src_q = dst->src[0];
+ ggml_tensor * src_k = dst->src[1];
+ ggml_tensor * src_v = dst->src[2];
+ ggml_tensor * src_g = dst->src[3];
+ ggml_tensor * src_beta = dst->src[4];
+ ggml_tensor * src_state = dst->src[5];
+
+ const int64_t S_v = src_q->ne[0];
+ const int64_t H = src_q->ne[1];
+ const int64_t n_tokens = src_q->ne[2];
+ const int64_t n_seqs = src_q->ne[3];
+
+ const float * q_d = (const float *) src_q->data;
+ const float * k_d = (const float *) src_k->data;
+ const float * v_d = (const float *) src_v->data;
+ const float * g_d = (const float *) src_g->data;
+
+ const float * b_d = (const float *) src_beta->data;
+ const float * s_d = (const float *) src_state->data;
+
+ float * dst_d = (float *) dst->data;
+
+ GGML_ASSERT(ggml_is_contiguous(src_q));
+ GGML_ASSERT(ggml_is_contiguous(src_k));
+ GGML_ASSERT(ggml_is_contiguous(src_v));
+ GGML_ASSERT(ggml_is_contiguous(src_g));
+ GGML_ASSERT(ggml_is_contiguous(src_beta));
+ GGML_ASSERT(ggml_is_contiguous(src_state));
+
+ dim3 grid_dims(H, n_seqs, 1);
+ dim3 block_dims(S_v, 1, 1);
+
+ cudaStream_t stream = ctx.stream();
+
+ switch(S_v) {
+ case 32:
+ gated_delta_net_cuda<32><<<grid_dims, block_dims, 0, stream>>>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs);
+ break;
+ case 64:
+ gated_delta_net_cuda<64><<<grid_dims, block_dims, 0, stream>>>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs);
+ break;
+ case 128:
+ gated_delta_net_cuda<128><<<grid_dims, block_dims, 0, stream>>>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs);
+ break;
+ default:
+ GGML_ABORT("fatal error");
+ break;
+ }
+}
diff --git
a/ggml/src/ggml-cuda/gated_delta_net.cuh b/ggml/src/ggml-cuda/gated_delta_net.cuh
new file mode 100644
index 0000000000..7375e81c0c
--- /dev/null
+++ b/ggml/src/ggml-cuda/gated_delta_net.cuh
@@ -0,0 +1,4 @@
+#include "common.cuh"
+#include "ggml.h"
+
+void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index b163468789..4ec490393a 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -53,6 +53,7 @@
 #include "ggml-cuda/upscale.cuh"
 #include "ggml-cuda/wkv.cuh"
 #include "ggml-cuda/gla.cuh"
+#include "ggml-cuda/gated_delta_net.cuh"
 #include "ggml-cuda/set.cuh"
 #include "ggml-cuda/set-rows.cuh"
 #include "ggml-cuda/pad_reflect_1d.cuh"
@@ -2730,6 +2731,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
 case GGML_OP_GATED_LINEAR_ATTN:
 ggml_cuda_op_gated_linear_attn(ctx, dst);
 break;
+ case GGML_OP_GATED_DELTA_NET:
+ ggml_cuda_op_gated_delta_net(ctx, dst);
+ break;
 case GGML_OP_RWKV_WKV7:
 ggml_cuda_op_rwkv_wkv7(ctx, dst);
 break;
@@ -4844,6 +4848,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
 case GGML_OP_LEAKY_RELU:
 case GGML_OP_RWKV_WKV6:
 case GGML_OP_GATED_LINEAR_ATTN:
+ case GGML_OP_GATED_DELTA_NET:
 case GGML_OP_RWKV_WKV7:
 return true;
 case GGML_OP_FLASH_ATTN_EXT:
diff --git a/src/models/qwen3next.cpp b/src/models/qwen3next.cpp
index 886eb3d66f..99b1a76a48 100644
--- a/src/models/qwen3next.cpp
+++ b/src/models/qwen3next.cpp
@@ -781,31 +781,15 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
 cb(k_conv, "k_conv_predelta", il);
 cb(v_conv, "v_conv_predelta", il);

- // Choose between build_delta_net_chunking and fused ggml_gated_delta_net based on n_tokens
- ggml_tensor * output;
- ggml_tensor * new_state;
+ // Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens
+ std::pair<ggml_tensor *, ggml_tensor *> attn_out; // pair of (output, new_state)
 if (n_seq_tokens == 1) {
- // Fused op expects state as [S_v*S_v*H, n_seqs]
- ggml_tensor * state_2d = ggml_reshape_2d(ctx0, state, head_v_dim * head_v_dim * num_v_heads, n_seqs);
- ggml_tensor * result = ggml_gated_delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state_2d,
- hparams.f_norm_rms_eps);
-
- // Unpack: attn scores then new state
- const int64_t attn_elems = head_v_dim * num_v_heads * n_seq_tokens * n_seqs;
- const int64_t state_elems = head_v_dim * head_v_dim * num_v_heads * n_seqs;
-
- output = ggml_view_4d(ctx0, result, head_v_dim, num_v_heads, n_seq_tokens, n_seqs,
- head_v_dim * sizeof(float),
- head_v_dim * num_v_heads * sizeof(float),
- head_v_dim * num_v_heads * n_seq_tokens * sizeof(float),
- 0);
- new_state = ggml_view_1d(ctx0, result, state_elems, attn_elems * sizeof(float));
+ attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
 } else {
- std::pair<ggml_tensor *, ggml_tensor *> attn_out;
- attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
- output = attn_out.first;
- new_state = attn_out.second;
+ attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
 }
+ ggml_tensor * output = attn_out.first;
+ ggml_tensor * new_state = attn_out.second;
 cb(output, "attn_output", il);
 cb(new_state, "new_state", il);
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 3bdb3d7ba9..802e8a5d78 100644
---
a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -3635,6 +3635,35 @@ struct test_rwkv_wkv6 : public test_case {
 }
 };

+// GGML_OP_GATED_DELTA_NET
+struct test_gated_delta_net : public test_case {
+ const ggml_type type;
+
+ const int64_t head_count;
+ const int64_t head_size;
+ const int64_t n_seq_tokens;
+ const int64_t n_seqs;
+
+ std::string vars() override {
+ return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
+ }
+
+ test_gated_delta_net(ggml_type type = GGML_TYPE_F32,
+ int64_t head_count = 4, int64_t head_size = 16, int64_t n_seq_tokens = 1, int64_t n_seqs = 1)
+ : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
+
+ ggml_tensor * build_graph(ggml_context * ctx) override {
+ ggml_tensor * q = ggml_new_tensor_4d(ctx, type, head_size, head_count, n_seq_tokens, n_seqs);
+ ggml_tensor * k = ggml_new_tensor_4d(ctx, type, head_size, head_count, n_seq_tokens, n_seqs);
+ ggml_tensor * v = ggml_new_tensor_4d(ctx, type, head_size, head_count, n_seq_tokens, n_seqs);
+ ggml_tensor * g = ggml_new_tensor_3d(ctx, type, head_count, n_seq_tokens, n_seqs);
+ ggml_tensor * beta = ggml_new_tensor_3d(ctx, type, head_count, n_seq_tokens, n_seqs);
+ ggml_tensor * state = ggml_new_tensor_2d(ctx, type, head_size * head_size * head_count, n_seqs);
+ ggml_tensor * out = ggml_gated_delta_net(ctx, q, k, v, g, beta, state, 1e-6f);
+ return out;
+ }
+};
+
 // GGML_OP_GATED_LINEAR_ATTN
 struct test_gla : public test_case {
 const ggml_type type;
@@ -8310,6 +8339,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
 }
 }

+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 1, 1));
+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 16, 64, 1, 2));
+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 1));
+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2));
+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 8, 32, 4, 2));
+
 #if 0
 // these tests are disabled to save execution time, but they can be handy for debugging
 test_cases.emplace_back(new test_llama(2, true));

From 15d83e0c87530b1166acec0550f6320aa0fa2461 Mon Sep 17 00:00:00 2001
From: Aman Gupta
Date: Thu, 12 Feb 2026 21:51:51 +0530
Subject: [PATCH 7/9] cpu: support for non-contig q,k,v

---
 ggml/include/ggml.h | 3 +-
 ggml/src/ggml-cpu/ggml-cpu.c | 2 +-
 ggml/src/ggml-cpu/ops.cpp | 68 +++++++++++++++++++++++-------------
 ggml/src/ggml.c | 19 +++-------
 tests/test-backend-ops.cpp | 19 +++++-----
 5 files changed, 60 insertions(+), 51 deletions(-)

diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 8edd675d41..82b27d2343 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -2474,8 +2474,7 @@ extern "C" {
 struct ggml_tensor * v,
 struct ggml_tensor * g,
 struct ggml_tensor * beta,
- struct ggml_tensor * state,
- float eps);
+ struct ggml_tensor * state);

 // custom operators

diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index ac12e9c99f..358af8c53e 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -2912,7 +2912,7 @@ struct ggml_cplan ggml_graph_plan(
 } break;
 case GGML_OP_GATED_DELTA_NET:
 {
- const int64_t S_v = node->src[0]->ne[0];
+ const int64_t S_v = node->src[2]->ne[0];
 cur = (S_v * S_v + S_v) * sizeof(float) * n_tasks;
 } break;
 case GGML_OP_COUNT:
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index fcc6dde15b..31c550968e 100644
---
 a/ggml/src/ggml-cpu/ops.cpp
+++
b/ggml/src/ggml-cpu/ops.cpp
@@ -10310,22 +10310,35 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
 ggml_tensor * src_beta = dst->src[4];
 ggml_tensor * src_state = dst->src[5];

- const int64_t S_v = src_q->ne[0];
- const int64_t H = src_q->ne[1];
- const int64_t n_tokens = src_q->ne[2];
- const int64_t n_seqs = src_q->ne[3];
+ const int64_t S_v = src_v->ne[0];
+ const int64_t H = src_v->ne[1];
+ const int64_t n_tokens = src_v->ne[2];
+ const int64_t n_seqs = src_v->ne[3];

- GGML_ASSERT(ggml_is_contiguous(src_q));
- GGML_ASSERT(ggml_is_contiguous(src_k));
- GGML_ASSERT(ggml_is_contiguous(src_v));
+ GGML_ASSERT(ggml_is_contiguous_rows(src_q));
+ GGML_ASSERT(ggml_is_contiguous_rows(src_k));
+ GGML_ASSERT(ggml_is_contiguous_rows(src_v));
 GGML_ASSERT(ggml_is_contiguous(src_g));
 GGML_ASSERT(ggml_is_contiguous(src_beta));
 GGML_ASSERT(ggml_is_contiguous(src_state));

+ // TODO: to support KDA
+ GGML_ASSERT(ggml_are_same_shape(src_beta, src_g));
+
+ GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);
+ GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb);
+ GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne);
+ GGML_TENSOR_LOCALS(size_t, nbk, src_k, nb);
+ GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne);
+ GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb);
+ GGML_TENSOR_LOCALS(int64_t, neg, src_g, ne);
+ GGML_TENSOR_LOCALS(size_t, nbg, src_g, nb);
+
 // scratch layout per thread: [s_t(S_v*S_v) | delta(S_v)]
 // s_t holds the transposed (row-major) state for contiguous vector ops
 const int64_t scratch_per_thread = S_v * S_v + S_v;
 const int ith = params->ith;
+
 float * scratch = (float *)params->wdata + ith * scratch_per_thread + CACHE_LINE_SIZE_F32;

 float * s_t = scratch;
@@ -10339,20 +10352,28 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
 float * state_out_base = (float *)dst->data + attn_score_elems;

 const float * state_in_base = (const float *)src_state->data;
- const float * q_base = (const float *)src_q->data;
- const float * k_base = (const float *)src_k->data;
- const float * v_base = (const float *)src_v->data;
 const float * g_base = (const float *)src_g->data;
 const float * beta_base = (const float *)src_beta->data;

- for (int64_t ir = ir0; ir < ir1; ++ir) {
- const int64_t h_idx = ir % H;
- const int64_t sequence = ir / H;
+ const int64_t rq1 = nev1 / neq1;
+ const int64_t rk1 = nev1 / nek1;
+ const int64_t rq3 = nev3 / neq3;
+ const int64_t rk3 = nev3 / nek3;

- float * s_out = state_out_base + (sequence * H + h_idx) * S_v * S_v;
+ for (int64_t ir = ir0; ir < ir1; ++ir) {
+ const int64_t iv1 = ir % H; // head_index
+ const int64_t iv3 = ir / H; // sequence
+
+ const int64_t iq1 = iv1 / rq1;
+ const int64_t ik1 = iv1 / rk1;
+
+ const int64_t iq3 = iv3 / rq3;
+ const int64_t ik3 = iv3 / rk3;
+
+ float * s_out = state_out_base + (iv3 * H + iv1) * S_v * S_v;

 // transpose
- const float * s_in = state_in_base + (sequence * H + h_idx) * S_v * S_v;
+ const float * s_in = state_in_base + (iv3 * H + iv1) * S_v * S_v;
 for (int64_t j = 0; j < S_v; ++j) {
 for (int64_t i = 0; i < S_v; ++i) {
 s_t[j * S_v + i] = s_in[j + i * S_v];
@@ -10360,17 +10381,14 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
 }

 // attn output pointer for first token of this (head, seq)
- float * attn_data = attn_out_base + (sequence * n_tokens * H + h_idx) * S_v;
+ float * attn_data = attn_out_base + (iv3 * n_tokens * H + iv1) * S_v;

 for (int64_t t = 0; t < n_tokens; t++) {
- // input pointers for this (head, seq, token)
- // layout is contiguous [S_v, H, n_tokens, n_seqs]
- const int64_t qkv_offset = sequence * n_tokens *
H * S_v + t * H * S_v + h_idx * S_v; - const float * q_d = q_base + qkv_offset; - const float * k_d = k_base + qkv_offset; - const float * v_d = v_base + qkv_offset; + const float * q_d = (const float *)((const char *)src_q->data + iq3 * nbq3 + t * nbq2 + iq1 * nbq1); + const float * k_d = (const float *)((const char *)src_k->data + ik3 * nbk3 + t * nbk2 + ik1 * nbk1); + const float * v_d = (const float *)((const char *)src_v->data + iv3 * nbv3 + t * nbv2 + iv1 * nbv1); - const int64_t gb_offset = sequence * n_tokens * H + t * H + h_idx; + const int64_t gb_offset = iv3 * neg1 * neg0 + t * neg0 + iv1; const float beta_val_raw = beta_base[gb_offset]; const float beta_val = 1.0f / (1.0f + expf(-beta_val_raw)); // sigmoid const float g_val = expf(g_base[gb_offset]); @@ -10410,8 +10428,8 @@ static void ggml_compute_forward_gated_delta_net_f32( const ggml_compute_params * params, ggml_tensor * dst) { - ggml_tensor * Q = dst->src[0]; - int64_t nr = Q->ne[1] * Q->ne[3]; + ggml_tensor * V = dst->src[2]; + int64_t nr = V->ne[1] * V->ne[3]; // disable for NUMA const bool disable_chunking = ggml_is_numa(); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index e25287b8a3..abaf94eb43 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -6112,8 +6112,7 @@ struct ggml_tensor * ggml_gated_delta_net( struct ggml_tensor * v, struct ggml_tensor * g, struct ggml_tensor * beta, - struct ggml_tensor * state, - float eps) { + struct ggml_tensor * state) { GGML_ASSERT(ggml_is_contiguous(q)); GGML_ASSERT(ggml_is_contiguous(k)); GGML_ASSERT(ggml_is_contiguous(v)); @@ -6128,17 +6127,11 @@ struct ggml_tensor * ggml_gated_delta_net( GGML_ASSERT(beta->type == GGML_TYPE_F32); GGML_ASSERT(state->type == GGML_TYPE_F32); - const int64_t S_k = q->ne[0]; - const int64_t H = q->ne[1]; - const int64_t n_tokens = q->ne[2]; - const int64_t n_seqs = q->ne[3]; + const int64_t S_v = v->ne[0]; + const int64_t H = v->ne[1]; + const int64_t n_tokens = v->ne[2]; + const int64_t n_seqs = v->ne[3]; - const int64_t S_v = v->ne[0]; - - GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H && k->ne[2] == n_tokens && k->ne[3] == n_seqs); - GGML_ASSERT(v->ne[1] == H && v->ne[2] == n_tokens && v->ne[3] == n_seqs); - GGML_ASSERT(g->ne[0] == H && g->ne[1] == n_tokens && g->ne[2] == n_seqs); - GGML_ASSERT(beta->ne[0] == H && beta->ne[1] == n_tokens && beta->ne[2] == n_seqs); GGML_ASSERT(ggml_nelements(state) == S_v * S_v * H * n_seqs); // concat output and new_state into a single tensor @@ -6146,8 +6139,6 @@ struct ggml_tensor * ggml_gated_delta_net( const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + S_v * n_seqs, 1, 1 }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - ggml_set_op_params_f32(result, 0, eps); - result->op = GGML_OP_GATED_DELTA_NET; result->src[0] = q; result->src[1] = k; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 802e8a5d78..824aa8a0c2 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -3643,23 +3643,24 @@ struct test_gated_delta_net : public test_case { const int64_t head_size; const int64_t n_seq_tokens; const int64_t n_seqs; + const int v_repeat; std::string vars() override { - return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs); + return VARS_TO_STR6(type, head_count, head_size, n_seq_tokens, n_seqs, v_repeat); } test_gated_delta_net(ggml_type type = GGML_TYPE_F32, - int64_t head_count = 4, int64_t head_size = 16, int64_t n_seq_tokens = 1, int64_t n_seqs = 1) - : type(type), head_count(head_count), head_size(head_size), 
n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
+ test_gated_delta_net(ggml_type type = GGML_TYPE_F32,
+ int64_t head_count = 4, int64_t head_size = 16, int64_t n_seq_tokens = 1, int64_t n_seqs = 1, int v_repeat = 1)
+ : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs), v_repeat(v_repeat) {}

 ggml_tensor * build_graph(ggml_context * ctx) override {
 ggml_tensor * q = ggml_new_tensor_4d(ctx, type, head_size, head_count, n_seq_tokens, n_seqs);
 ggml_tensor * k = ggml_new_tensor_4d(ctx, type, head_size, head_count, n_seq_tokens, n_seqs);
- ggml_tensor * v = ggml_new_tensor_4d(ctx, type, head_size, head_count, n_seq_tokens, n_seqs);
- ggml_tensor * g = ggml_new_tensor_3d(ctx, type, head_count, n_seq_tokens, n_seqs);
- ggml_tensor * beta = ggml_new_tensor_3d(ctx, type, head_count, n_seq_tokens, n_seqs);
- ggml_tensor * state = ggml_new_tensor_2d(ctx, type, head_size * head_size * head_count, n_seqs);
- ggml_tensor * out = ggml_gated_delta_net(ctx, q, k, v, g, beta, state, 1e-6f);
+ ggml_tensor * v = ggml_new_tensor_4d(ctx, type, head_size, head_count * v_repeat, n_seq_tokens, n_seqs);
+ ggml_tensor * g = ggml_new_tensor_3d(ctx, type, head_count * v_repeat, n_seq_tokens, n_seqs);
+ ggml_tensor * beta = ggml_new_tensor_3d(ctx, type, head_count * v_repeat, n_seq_tokens, n_seqs);
+ ggml_tensor * state = ggml_new_tensor_2d(ctx, type, head_size * v_repeat * head_size * head_count, n_seqs);
+ ggml_tensor * out = ggml_gated_delta_net(ctx, q, k, v, g, beta, state);
 return out;
 }
 };
@@ -8343,7 +8344,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
 test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 16, 64, 1, 2));
 test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 1));
 test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2));
- test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 8, 32, 4, 2));
+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 8, 32, 4, 2, 2));

 #if 0
 // these tests are disabled to save execution time, but they can be handy for debugging

From 2f0ac21d4ba85b809860fce77dc10a807127c080 Mon Sep 17 00:00:00 2001
From: Aman Gupta
Date: Fri, 13 Feb 2026 14:12:09 +0100
Subject: [PATCH 8/9] cuda: add support for non-contig q,k,v

---
 ggml/src/ggml-cuda/gated_delta_net.cu | 122 +++++++++++++++++---------
 1 file changed, 79 insertions(+), 43 deletions(-)

diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu
index 88dc6c65d7..6f48e1e467 100644
--- a/ggml/src/ggml-cuda/gated_delta_net.cu
+++ b/ggml/src/ggml-cuda/gated_delta_net.cu
@@ -1,31 +1,42 @@
-#include "ggml-cuda/common.cuh"
 #include "gated_delta_net.cuh"
+#include "ggml-cuda/common.cuh"

-template <int S_v>
-__global__ void gated_delta_net_cuda(
- const float * q,
- const float * k,
- const float * v,
- const float * g,
- const float * beta,
- const float * curr_state,
- float * dst,
- int64_t H,
- int64_t n_tokens,
- int64_t n_seqs
-) {
+template <int S_v>
+__global__ void gated_delta_net_cuda(const float * q,
+ const float * k,
+ const float * v,
+ const float * g,
+ const float * beta,
+ const float * curr_state,
+ float * dst,
+ int64_t H,
+ int64_t n_tokens,
+ int64_t n_seqs,
+ int64_t sq1,
+ int64_t sq2,
+ int64_t sq3,
+ int64_t sv1,
+ int64_t sv2,
+ int64_t sv3,
+ int64_t sg1,
+ int64_t sg2,
+ int64_t rq1,
+ int64_t rq3) {
 const int64_t h_idx = blockIdx.x;
 const int64_t sequence = blockIdx.y;
- const int col = threadIdx.x; // each thread owns one column
+ const int col = threadIdx.x; // each thread owns one column
+
+ const int64_t iq1 = h_idx /
rq1; + const int64_t iq3 = sequence / rq3; const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; - float * attn_data = dst; - float * state = dst + attn_score_elems; + float * attn_data = dst; + float * state = dst + attn_score_elems; const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v; - state += state_offset; + state += state_offset; curr_state += state_offset; - attn_data += (sequence * n_tokens * H + h_idx) * S_v; + attn_data += (sequence * n_tokens * H + h_idx) * S_v; // Copy input state to output state (working area) #pragma unroll @@ -34,14 +45,15 @@ __global__ void gated_delta_net_cuda( } for (int t = 0; t < n_tokens; t++) { - const int64_t qkv_offset = sequence * n_tokens * H * S_v + t * H * S_v + h_idx * S_v; - const float * q_t = q + qkv_offset; - const float * k_t = k + qkv_offset; - const float * v_t = v + qkv_offset; + const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1; + const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1; + const float * v_t = v + sequence * sv3 + t * sv2 + h_idx * sv1; - const int64_t gb_offset = sequence * n_tokens * H + t * H + h_idx; - const float beta_val = 1.0f / (1.0f + expf(-beta[gb_offset])); - const float g_val = expf(g[gb_offset]); + const float * g_t = g + sequence * sg2 + t * sg1; + const float * beta_t = beta + sequence * sg2 + t * sg1; + + const float beta_val = 1.0f / (1.0f + expf(-beta_t[h_idx])); + const float g_val = expf(g_t[h_idx]); // kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i] float kv_col = 0.0f; @@ -70,9 +82,7 @@ __global__ void gated_delta_net_cuda( } } -void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, - ggml_tensor * dst) { - +void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_tensor * src_q = dst->src[0]; ggml_tensor * src_k = dst->src[1]; ggml_tensor * src_v = dst->src[2]; @@ -80,42 +90,68 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * src_beta = dst->src[4]; ggml_tensor * src_state = dst->src[5]; - const int64_t S_v = src_q->ne[0]; - const int64_t H = src_q->ne[1]; - const int64_t n_tokens = src_q->ne[2]; - const int64_t n_seqs = src_q->ne[3]; + GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne); + GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb); + GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne); + GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb); + GGML_TENSOR_LOCALS(size_t, nbg, src_g, nb); + + const int64_t S_v = nev0; + const int64_t H = nev1; + const int64_t n_tokens = nev2; + const int64_t n_seqs = nev3; + + const int64_t rq1 = nev1 / neq1; + const int64_t rq3 = nev3 / neq3; const float * q_d = (const float *) src_q->data; const float * k_d = (const float *) src_k->data; const float * v_d = (const float *) src_v->data; const float * g_d = (const float *) src_g->data; - const float * b_d = (const float *) src_beta->data; - const float * s_d = (const float *) src_state->data; - float * dst_d = (float *) dst->data; + const float * s_d = (const float *) src_state->data; + float * dst_d = (float *) dst->data; - GGML_ASSERT(ggml_is_contiguous(src_q)); - GGML_ASSERT(ggml_is_contiguous(src_k)); - GGML_ASSERT(ggml_is_contiguous(src_v)); + GGML_ASSERT(ggml_is_contiguous_rows(src_q)); + GGML_ASSERT(ggml_is_contiguous_rows(src_k)); + GGML_ASSERT(ggml_is_contiguous_rows(src_v)); + GGML_ASSERT(ggml_are_same_stride(src_q, src_k)); + GGML_ASSERT(ggml_are_same_stride(src_g, src_beta)); GGML_ASSERT(ggml_is_contiguous(src_g)); GGML_ASSERT(ggml_is_contiguous(src_beta)); GGML_ASSERT(ggml_is_contiguous(src_state)); + // strides in 
floats
+ const int64_t sq1 = nbq1 / sizeof(float);
+ const int64_t sq2 = nbq2 / sizeof(float);
+ const int64_t sq3 = nbq3 / sizeof(float);
+ const int64_t sv1 = nbv1 / sizeof(float);
+ const int64_t sv2 = nbv2 / sizeof(float);
+ const int64_t sv3 = nbv3 / sizeof(float);
+ const int64_t sg1 = nbg1 / sizeof(float);
+ const int64_t sg2 = nbg2 / sizeof(float);
+
 dim3 grid_dims(H, n_seqs, 1);
 dim3 block_dims(S_v, 1, 1);

 cudaStream_t stream = ctx.stream();

- switch(S_v) {
+ switch (S_v) {
 case 32:
- gated_delta_net_cuda<32><<<grid_dims, block_dims, 0, stream>>>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs);
+ gated_delta_net_cuda<32><<<grid_dims, block_dims, 0, stream>>>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
+ n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2,
+ sv3, sg1, sg2, rq1, rq3);
 break;
 case 64:
- gated_delta_net_cuda<64><<<grid_dims, block_dims, 0, stream>>>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs);
+ gated_delta_net_cuda<64><<<grid_dims, block_dims, 0, stream>>>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
+ n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2,
+ sv3, sg1, sg2, rq1, rq3);
 break;
 case 128:
- gated_delta_net_cuda<128><<<grid_dims, block_dims, 0, stream>>>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs);
+ gated_delta_net_cuda<128><<<grid_dims, block_dims, 0, stream>>>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
+ n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2,
+ sv3, sg1, sg2, rq1, rq3);
 break;
 default:
 GGML_ABORT("fatal error");

From 3db6e5ef224eef9e550942f1fab4704307e84b1f Mon Sep 17 00:00:00 2001
From: Aman Gupta
Date: Fri, 13 Feb 2026 14:18:41 +0100
Subject: [PATCH 9/9] add permuted test-case

---
 ggml/src/ggml.c | 6 +++---
 tests/test-backend-ops.cpp | 26 ++++++++++++++++++++------
 2 files changed, 23 insertions(+), 9 deletions(-)

diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index abaf94eb43..1fbf719a70 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -6113,9 +6113,9 @@ struct ggml_tensor * ggml_gated_delta_net(
 struct ggml_tensor * g,
 struct ggml_tensor * beta,
 struct ggml_tensor * state) {
- GGML_ASSERT(ggml_is_contiguous(q));
- GGML_ASSERT(ggml_is_contiguous(k));
- GGML_ASSERT(ggml_is_contiguous(v));
+ GGML_ASSERT(ggml_is_contiguous_rows(q));
+ GGML_ASSERT(ggml_is_contiguous_rows(k));
+ GGML_ASSERT(ggml_is_contiguous_rows(v));
 GGML_ASSERT(ggml_is_contiguous(g));
 GGML_ASSERT(ggml_is_contiguous(beta));
 GGML_ASSERT(ggml_is_contiguous(state));
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 824aa8a0c2..5dd69f1d70 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -3646,17 +3646,29 @@ struct test_gated_delta_net : public test_case {
 const int v_repeat;

 std::string vars() override {
- return VARS_TO_STR6(type, head_count, head_size, n_seq_tokens, n_seqs, v_repeat);
+ return VARS_TO_STR7(type, head_count, head_size, n_seq_tokens, n_seqs, v_repeat, permuted);
 }

+ const bool permuted;
+
 test_gated_delta_net(ggml_type type = GGML_TYPE_F32,
- int64_t head_count = 4, int64_t head_size = 16, int64_t n_seq_tokens = 1, int64_t n_seqs = 1, int v_repeat = 1)
- : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs), v_repeat(v_repeat) {}
+ int64_t head_count = 4, int64_t head_size = 16, int64_t n_seq_tokens = 1, int64_t n_seqs = 1, int v_repeat = 1, bool permuted = false)
+ : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs), v_repeat(v_repeat), permuted(permuted) {}

 ggml_tensor * build_graph(ggml_context * ctx) override {
- ggml_tensor * q = ggml_new_tensor_4d(ctx, type, head_size, head_count, n_seq_tokens, n_seqs);
- ggml_tensor * k = ggml_new_tensor_4d(ctx, type, head_size, head_count, n_seq_tokens, n_seqs);
-
ggml_tensor * v = ggml_new_tensor_4d(ctx, type, head_size, head_count * v_repeat, n_seq_tokens, n_seqs);
+ ggml_tensor * q;
+ ggml_tensor * k;
+ ggml_tensor * v;
+ if (permuted) {
+ // create with dims 1 and 2 swapped, then permute back to get non-contiguous layout
+ q = ggml_permute(ctx, ggml_new_tensor_4d(ctx, type, head_size, n_seq_tokens, head_count, n_seqs), 0, 2, 1, 3);
+ k = ggml_permute(ctx, ggml_new_tensor_4d(ctx, type, head_size, n_seq_tokens, head_count, n_seqs), 0, 2, 1, 3);
+ v = ggml_permute(ctx, ggml_new_tensor_4d(ctx, type, head_size, n_seq_tokens, head_count * v_repeat, n_seqs), 0, 2, 1, 3);
+ } else {
+ q = ggml_new_tensor_4d(ctx, type, head_size, head_count, n_seq_tokens, n_seqs);
+ k = ggml_new_tensor_4d(ctx, type, head_size, head_count, n_seq_tokens, n_seqs);
+ v = ggml_new_tensor_4d(ctx, type, head_size, head_count * v_repeat, n_seq_tokens, n_seqs);
+ }
 ggml_tensor * g = ggml_new_tensor_3d(ctx, type, head_count * v_repeat, n_seq_tokens, n_seqs);
 ggml_tensor * beta = ggml_new_tensor_3d(ctx, type, head_count * v_repeat, n_seq_tokens, n_seqs);
 ggml_tensor * state = ggml_new_tensor_2d(ctx, type, head_size * v_repeat * head_size * head_count, n_seqs);
@@ -8345,6 +8357,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
 test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 1));
 test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2));
 test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 8, 32, 4, 2, 2));
+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2, 1, true));
+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 1, 1, true));

 #if 0
 // these tests are disabled to save execution time, but they can be handy for debugging
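
For reference, a minimal standalone sketch of the per-token recurrence that GGML_OP_GATED_DELTA_NET computes, mirroring the simplified CPU path from PATCH 6. This is illustrative code, not part of the patch series; the function and buffer names are hypothetical and the state is kept as a plain row-major S_v x S_v matrix without the ggml vector helpers:

#include <cmath>
#include <vector>

// One step of the gated delta rule for a single head:
//   S     <- exp(g) * S                    (state decay)
//   delta  = (v - S k) * sigmoid(beta)     (delta-rule error, gated write strength)
//   S     <- S + delta k^T                 (rank-1 state update)
//   out    = S q                           (readout)
static void gated_delta_step(std::vector<float> & S, // state, S_v x S_v, row-major
                             const float * q,        // query, length S_v
                             const float * k,        // key,   length S_v
                             const float * v,        // value, length S_v
                             float g_raw,            // log-space decay gate
                             float beta_raw,         // pre-sigmoid write strength
                             float * out,            // output, length S_v
                             int S_v) {
    const float g_val    = std::exp(g_raw);
    const float beta_val = 1.0f / (1.0f + std::exp(-beta_raw));

    // state decay
    for (float & s : S) {
        s *= g_val;
    }

    // delta[j] = (v[j] - (S k)[j]) * sigmoid(beta)
    std::vector<float> delta(S_v);
    for (int j = 0; j < S_v; ++j) {
        float kv = 0.0f;
        for (int i = 0; i < S_v; ++i) {
            kv += S[j * S_v + i] * k[i];
        }
        delta[j] = (v[j] - kv) * beta_val;
    }

    // rank-1 update: S[j][i] += delta[j] * k[i]
    for (int j = 0; j < S_v; ++j) {
        for (int i = 0; i < S_v; ++i) {
            S[j * S_v + i] += delta[j] * k[i];
        }
    }

    // readout: out = S q
    for (int j = 0; j < S_v; ++j) {
        float acc = 0.0f;
        for (int i = 0; i < S_v; ++i) {
            acc += S[j * S_v + i] * q[i];
        }
        out[j] = acc;
    }
}

The CUDA kernel in the patches computes the same recurrence with one thread per state column, and folds the decay into the update (delta = (v - g*kv) * beta; S = g*S + k*delta) instead of scaling the whole state first; the two formulations are algebraically equivalent.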