From 15d83e0c87530b1166acec0550f6320aa0fa2461 Mon Sep 17 00:00:00 2001
From: Aman Gupta
Date: Thu, 12 Feb 2026 21:51:51 +0530
Subject: [PATCH] cpu: support for non-contig q,k,v

---
 ggml/include/ggml.h          |  3 +-
 ggml/src/ggml-cpu/ggml-cpu.c |  2 +-
 ggml/src/ggml-cpu/ops.cpp    | 68 +++++++++++++++++++++++-------------
 ggml/src/ggml.c              | 19 +++-------
 tests/test-backend-ops.cpp   | 19 +++++-----
 5 files changed, 60 insertions(+), 51 deletions(-)

diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 8edd675d41..82b27d2343 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -2474,8 +2474,7 @@ extern "C" {
             struct ggml_tensor * v,
             struct ggml_tensor * g,
             struct ggml_tensor * beta,
-            struct ggml_tensor * state,
-            float eps);
+            struct ggml_tensor * state);
 
     // custom operators
 
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index ac12e9c99f..358af8c53e 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -2912,7 +2912,7 @@ struct ggml_cplan ggml_graph_plan(
                 } break;
             case GGML_OP_GATED_DELTA_NET:
                 {
-                    const int64_t S_v = node->src[0]->ne[0];
+                    const int64_t S_v = node->src[2]->ne[0];
                     cur = (S_v * S_v + S_v) * sizeof(float) * n_tasks;
                 } break;
             case GGML_OP_COUNT:
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index fcc6dde15b..31c550968e 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -10310,22 +10310,35 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
     ggml_tensor * src_beta = dst->src[4];
     ggml_tensor * src_state = dst->src[5];
 
-    const int64_t S_v = src_q->ne[0];
-    const int64_t H = src_q->ne[1];
-    const int64_t n_tokens = src_q->ne[2];
-    const int64_t n_seqs = src_q->ne[3];
+    const int64_t S_v = src_v->ne[0];
+    const int64_t H = src_v->ne[1];
+    const int64_t n_tokens = src_v->ne[2];
+    const int64_t n_seqs = src_v->ne[3];
 
-    GGML_ASSERT(ggml_is_contiguous(src_q));
-    GGML_ASSERT(ggml_is_contiguous(src_k));
-    GGML_ASSERT(ggml_is_contiguous(src_v));
+    GGML_ASSERT(ggml_is_contiguous_rows(src_q));
+    GGML_ASSERT(ggml_is_contiguous_rows(src_k));
+    GGML_ASSERT(ggml_is_contiguous_rows(src_v));
     GGML_ASSERT(ggml_is_contiguous(src_g));
     GGML_ASSERT(ggml_is_contiguous(src_beta));
     GGML_ASSERT(ggml_is_contiguous(src_state));
 
+    // TODO: support KDA
+    GGML_ASSERT(ggml_are_same_shape(src_beta, src_g));
+
+    GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);
+    GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb);
+    GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne);
+    GGML_TENSOR_LOCALS(size_t, nbk, src_k, nb);
+    GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne);
+    GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb);
+    GGML_TENSOR_LOCALS(int64_t, neg, src_g, ne);
+    GGML_TENSOR_LOCALS(size_t, nbg, src_g, nb);
+
     // scratch layout per thread: [s_t(S_v*S_v) | delta(S_v)]
     // s_t holds the transposed (row-major) state for contiguous vector ops
     const int64_t scratch_per_thread = S_v * S_v + S_v;
     const int ith = params->ith;
+
     float * scratch = (float *)params->wdata + ith * scratch_per_thread + CACHE_LINE_SIZE_F32;
 
     float * s_t = scratch;
@@ -10339,20 +10352,28 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
     float * state_out_base = (float *)dst->data + attn_score_elems;
 
     const float * state_in_base = (const float *)src_state->data;
-    const float * q_base = (const float *)src_q->data;
-    const float * k_base = (const float *)src_k->data;
-    const float * v_base = (const float *)src_v->data;
     const float * g_base = (const float *)src_g->data;
     const float * beta_base = (const float *)src_beta->data;
 
-    for (int64_t ir = ir0; ir < ir1; ++ir) {
-        const int64_t h_idx = ir % H;
-        const int64_t sequence = ir / H;
+    const int64_t rq1 = nev1 / neq1;
+    const int64_t rk1 = nev1 / nek1;
+    const int64_t rq3 = nev3 / neq3;
+    const int64_t rk3 = nev3 / nek3;
 
-        float * s_out = state_out_base + (sequence * H + h_idx) * S_v * S_v;
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        const int64_t iv1 = ir % H; // head_index
+        const int64_t iv3 = ir / H; // sequence
+
+        const int64_t iq1 = iv1 / rq1;
+        const int64_t ik1 = iv1 / rk1;
+
+        const int64_t iq3 = iv3 / rq3;
+        const int64_t ik3 = iv3 / rk3;
+
+        float * s_out = state_out_base + (iv3 * H + iv1) * S_v * S_v;
 
         // transpose
-        const float * s_in = state_in_base + (sequence * H + h_idx) * S_v * S_v;
+        const float * s_in = state_in_base + (iv3 * H + iv1) * S_v * S_v;
         for (int64_t j = 0; j < S_v; ++j) {
             for (int64_t i = 0; i < S_v; ++i) {
                 s_t[j * S_v + i] = s_in[j + i * S_v];
@@ -10360,17 +10381,14 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
         }
 
         // attn output pointer for first token of this (head, seq)
-        float * attn_data = attn_out_base + (sequence * n_tokens * H + h_idx) * S_v;
+        float * attn_data = attn_out_base + (iv3 * n_tokens * H + iv1) * S_v;
 
         for (int64_t t = 0; t < n_tokens; t++) {
-            // input pointers for this (head, seq, token)
-            // layout is contiguous [S_v, H, n_tokens, n_seqs]
-            const int64_t qkv_offset = sequence * n_tokens * H * S_v + t * H * S_v + h_idx * S_v;
-            const float * q_d = q_base + qkv_offset;
-            const float * k_d = k_base + qkv_offset;
-            const float * v_d = v_base + qkv_offset;
+            const float * q_d = (const float *)((const char *)src_q->data + iq3 * nbq3 + t * nbq2 + iq1 * nbq1);
+            const float * k_d = (const float *)((const char *)src_k->data + ik3 * nbk3 + t * nbk2 + ik1 * nbk1);
+            const float * v_d = (const float *)((const char *)src_v->data + iv3 * nbv3 + t * nbv2 + iv1 * nbv1);
 
-            const int64_t gb_offset = sequence * n_tokens * H + t * H + h_idx;
+            const int64_t gb_offset = iv3 * neg1 * neg0 + t * neg0 + iv1;
             const float beta_val_raw = beta_base[gb_offset];
             const float beta_val = 1.0f / (1.0f + expf(-beta_val_raw)); // sigmoid
             const float g_val = expf(g_base[gb_offset]);
@@ -10410,8 +10428,8 @@ static void ggml_compute_forward_gated_delta_net_f32(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
 
-    ggml_tensor * Q = dst->src[0];
-    int64_t nr = Q->ne[1] * Q->ne[3];
+    ggml_tensor * V = dst->src[2];
+    int64_t nr = V->ne[1] * V->ne[3];
 
     // disable for NUMA
     const bool disable_chunking = ggml_is_numa();
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index e25287b8a3..abaf94eb43 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -6112,8 +6112,7 @@ struct ggml_tensor * ggml_gated_delta_net(
         struct ggml_tensor * v,
         struct ggml_tensor * g,
         struct ggml_tensor * beta,
-        struct ggml_tensor * state,
-        float eps) {
+        struct ggml_tensor * state) {
     GGML_ASSERT(ggml_is_contiguous(q));
     GGML_ASSERT(ggml_is_contiguous(k));
     GGML_ASSERT(ggml_is_contiguous(v));
@@ -6128,17 +6127,11 @@ struct ggml_tensor * ggml_gated_delta_net(
     GGML_ASSERT(beta->type == GGML_TYPE_F32);
     GGML_ASSERT(state->type == GGML_TYPE_F32);
 
-    const int64_t S_k = q->ne[0];
-    const int64_t H = q->ne[1];
-    const int64_t n_tokens = q->ne[2];
-    const int64_t n_seqs = q->ne[3];
+    const int64_t S_v = v->ne[0];
+    const int64_t H = v->ne[1];
+    const int64_t n_tokens = v->ne[2];
+    const int64_t n_seqs = v->ne[3];
 
-    const int64_t S_v = v->ne[0];
-
-    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
-    GGML_ASSERT(v->ne[1] == H && v->ne[2] == n_tokens && v->ne[3] == n_seqs);
-    GGML_ASSERT(g->ne[0] == H && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
-    GGML_ASSERT(beta->ne[0] == H && beta->ne[1] == n_tokens && beta->ne[2] == n_seqs);
     GGML_ASSERT(ggml_nelements(state) == S_v * S_v * H * n_seqs);
 
     // concat output and new_state into a single tensor
@@ -6146,8 +6139,6 @@ struct ggml_tensor * ggml_gated_delta_net(
     const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + S_v * n_seqs, 1, 1 };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
-    ggml_set_op_params_f32(result, 0, eps);
-
     result->op = GGML_OP_GATED_DELTA_NET;
     result->src[0] = q;
     result->src[1] = k;
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 802e8a5d78..824aa8a0c2 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -3643,23 +3643,24 @@ struct test_gated_delta_net : public test_case {
     const int64_t head_count;
     const int64_t head_size;
     const int64_t n_seq_tokens;
     const int64_t n_seqs;
+    const int v_repeat;
 
     std::string vars() override {
-        return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
+        return VARS_TO_STR6(type, head_count, head_size, n_seq_tokens, n_seqs, v_repeat);
     }
 
     test_gated_delta_net(ggml_type type = GGML_TYPE_F32,
-            int64_t head_count = 4, int64_t head_size = 16, int64_t n_seq_tokens = 1, int64_t n_seqs = 1)
-        : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
+            int64_t head_count = 4, int64_t head_size = 16, int64_t n_seq_tokens = 1, int64_t n_seqs = 1, int v_repeat = 1)
+        : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs), v_repeat(v_repeat) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * q = ggml_new_tensor_4d(ctx, type, head_size, head_count, n_seq_tokens, n_seqs);
         ggml_tensor * k = ggml_new_tensor_4d(ctx, type, head_size, head_count, n_seq_tokens, n_seqs);
-        ggml_tensor * v = ggml_new_tensor_4d(ctx, type, head_size, head_count, n_seq_tokens, n_seqs);
-        ggml_tensor * g = ggml_new_tensor_3d(ctx, type, head_count, n_seq_tokens, n_seqs);
-        ggml_tensor * beta = ggml_new_tensor_3d(ctx, type, head_count, n_seq_tokens, n_seqs);
-        ggml_tensor * state = ggml_new_tensor_2d(ctx, type, head_size * head_size * head_count, n_seqs);
-        ggml_tensor * out = ggml_gated_delta_net(ctx, q, k, v, g, beta, state, 1e-6f);
+        ggml_tensor * v = ggml_new_tensor_4d(ctx, type, head_size, head_count * v_repeat, n_seq_tokens, n_seqs);
+        ggml_tensor * g = ggml_new_tensor_3d(ctx, type, head_count * v_repeat, n_seq_tokens, n_seqs);
+        ggml_tensor * beta = ggml_new_tensor_3d(ctx, type, head_count * v_repeat, n_seq_tokens, n_seqs);
+        ggml_tensor * state = ggml_new_tensor_2d(ctx, type, head_size * v_repeat * head_size * head_count, n_seqs);
+        ggml_tensor * out = ggml_gated_delta_net(ctx, q, k, v, g, beta, state);
         return out;
     }
 };
@@ -8343,7 +8344,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 16, 64, 1, 2));
     test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 1));
    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2));
-    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 8, 32, 4, 2));
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 8, 32, 4, 2, 2));
 
 #if 0
     // these tests are disabled to save execution time, but they can be handy for debugging
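
Reviewer note, not part of the patch: the kernel change has two ingredients. First, the old code computed one packed element offset (qkv_offset), which is only valid when q, k and v are laid out contiguously as [S_v, H, n_tokens, n_seqs]; the new code addresses them through their byte strides nb1..nb3, so any view whose rows of S_v floats are contiguous is accepted, which is exactly what the relaxed ggml_is_contiguous_rows asserts require. Second, head indices are broadcast GQA-style: with rq1 = nev1 / neq1, V head iv1 reads Q head iq1 = iv1 / rq1, so v may have more heads than q and k (exercised by the new v_repeat test parameter). The standalone C++ sketch below illustrates both ideas with toy dimensions; all names are illustrative and none of this code is taken from the patch. It checks that the stride-based addressing reduces to the old packed offset in the contiguous case:

    #include <cassert>
    #include <cstdint>
    #include <cstdio>

    int main() {
        // toy sizes: head size S, Q/K heads Hq, V heads Hv = Hq * repeat,
        // T tokens, N sequences (all illustrative)
        const int64_t S = 4, Hq = 2, T = 3, N = 2, repeat = 2;
        const int64_t Hv = Hq * repeat;

        // byte strides of a contiguous f32 tensor with ne = {S, Hq, T, N},
        // mirroring ggml's nb0..nb3 convention
        const size_t nb0 = sizeof(float);
        const size_t nb1 = nb0 * S;  // step to the next head
        const size_t nb2 = nb1 * Hq; // step to the next token
        const size_t nb3 = nb2 * T;  // step to the next sequence

        static float q[N][T][Hq][S] = {}; // packed storage in the same order

        const int64_t rq1 = Hv / Hq; // head broadcast ratio, cf. rq1 = nev1 / neq1

        for (int64_t iv3 = 0; iv3 < N; ++iv3) {          // sequence
            for (int64_t t = 0; t < T; ++t) {            // token
                for (int64_t iv1 = 0; iv1 < Hv; ++iv1) { // V head
                    const int64_t iq1 = iv1 / rq1;       // Q head feeding this V head
                    // old addressing: one packed offset, contiguous layout only
                    const float * a = &q[iv3][t][iq1][0];
                    // new addressing: byte strides, also valid for strided views
                    const float * b = (const float *)((const char *) q + iv3 * nb3 + t * nb2 + iq1 * nb1);
                    assert(a == b);
                }
            }
        }
        std::puts("stride-based addressing matches the packed special case");
        return 0;
    }

Since the strides describe the layout instead of assuming it, the same loop body keeps working when q, k or v are permuted or offset views, and the contiguous fast case costs nothing extra.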