From 2f0ac21d4ba85b809860fce77dc10a807127c080 Mon Sep 17 00:00:00 2001
From: Aman Gupta
Date: Fri, 13 Feb 2026 14:12:09 +0100
Subject: [PATCH] cuda: add support for non-contig q,k,v

---
 ggml/src/ggml-cuda/gated_delta_net.cu | 122 +++++++++++++++++---------
 1 file changed, 79 insertions(+), 43 deletions(-)

diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu
index 88dc6c65d7..6f48e1e467 100644
--- a/ggml/src/ggml-cuda/gated_delta_net.cu
+++ b/ggml/src/ggml-cuda/gated_delta_net.cu
@@ -1,31 +1,42 @@
-#include "ggml-cuda/common.cuh"
 #include "gated_delta_net.cuh"
+#include "ggml-cuda/common.cuh"
 
-template<int S_v>
-__global__ void gated_delta_net_cuda(
-    const float * q,
-    const float * k,
-    const float * v,
-    const float * g,
-    const float * beta,
-    const float * curr_state,
-    float * dst,
-    int64_t H,
-    int64_t n_tokens,
-    int64_t n_seqs
-) {
+template <int S_v>
+__global__ void gated_delta_net_cuda(const float * q,
+                                     const float * k,
+                                     const float * v,
+                                     const float * g,
+                                     const float * beta,
+                                     const float * curr_state,
+                                     float *       dst,
+                                     int64_t       H,
+                                     int64_t       n_tokens,
+                                     int64_t       n_seqs,
+                                     int64_t       sq1,
+                                     int64_t       sq2,
+                                     int64_t       sq3,
+                                     int64_t       sv1,
+                                     int64_t       sv2,
+                                     int64_t       sv3,
+                                     int64_t       sg1,
+                                     int64_t       sg2,
+                                     int64_t       rq1,
+                                     int64_t       rq3) {
     const int64_t h_idx    = blockIdx.x;
     const int64_t sequence = blockIdx.y;
-    const int col = threadIdx.x; // each thread owns one column
+    const int     col      = threadIdx.x;  // each thread owns one column
+
+    const int64_t iq1 = h_idx / rq1;
+    const int64_t iq3 = sequence / rq3;
 
     const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
-    float * attn_data = dst;
-    float * state = dst + attn_score_elems;
+    float *       attn_data        = dst;
+    float *       state            = dst + attn_score_elems;
 
     const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v;
-    state += state_offset;
+    state      += state_offset;
     curr_state += state_offset;
-    attn_data += (sequence * n_tokens * H + h_idx) * S_v;
+    attn_data  += (sequence * n_tokens * H + h_idx) * S_v;
 
     // Copy input state to output state (working area)
 #pragma unroll
@@ -34,14 +45,15 @@ __global__ void gated_delta_net_cuda(
     }
 
     for (int t = 0; t < n_tokens; t++) {
-        const int64_t qkv_offset = sequence * n_tokens * H * S_v + t * H * S_v + h_idx * S_v;
-        const float * q_t = q + qkv_offset;
-        const float * k_t = k + qkv_offset;
-        const float * v_t = v + qkv_offset;
+        const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1;
+        const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1;
+        const float * v_t = v + sequence * sv3 + t * sv2 + h_idx * sv1;
 
-        const int64_t gb_offset = sequence * n_tokens * H + t * H + h_idx;
-        const float beta_val = 1.0f / (1.0f + expf(-beta[gb_offset]));
-        const float g_val = expf(g[gb_offset]);
+        const float * g_t    = g + sequence * sg2 + t * sg1;
+        const float * beta_t = beta + sequence * sg2 + t * sg1;
+
+        const float beta_val = 1.0f / (1.0f + expf(-beta_t[h_idx]));
+        const float g_val    = expf(g_t[h_idx]);
 
         // kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i]
         float kv_col = 0.0f;
@@ -70,9 +82,7 @@
     }
 }
 
-void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx,
-                                  ggml_tensor * dst) {
-
+void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_tensor * src_q = dst->src[0];
     ggml_tensor * src_k = dst->src[1];
     ggml_tensor * src_v = dst->src[2];
@@ -80,42 +90,68 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx,
     ggml_tensor * src_beta = dst->src[4];
     ggml_tensor * src_state = dst->src[5];
 
-    const int64_t S_v = src_q->ne[0];
-    const int64_t H = src_q->ne[1];
-    const int64_t n_tokens = src_q->ne[2];
-    const int64_t n_seqs = src_q->ne[3];
+    GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);
+    GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb);
+    GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne);
+    GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb);
+    GGML_TENSOR_LOCALS(size_t, nbg, src_g, nb);
+
+    const int64_t S_v      = nev0;
+    const int64_t H        = nev1;
+    const int64_t n_tokens = nev2;
+    const int64_t n_seqs   = nev3;
+
+    const int64_t rq1 = nev1 / neq1;
+    const int64_t rq3 = nev3 / neq3;
 
     const float * q_d = (const float *) src_q->data;
     const float * k_d = (const float *) src_k->data;
     const float * v_d = (const float *) src_v->data;
     const float * g_d = (const float *) src_g->data;
-    const float * b_d = (const float *) src_beta->data;
-    const float * s_d = (const float *) src_state->data;
-    float * dst_d = (float *) dst->data;
+    const float * b_d   = (const float *) src_beta->data;
+    const float * s_d   = (const float *) src_state->data;
+    float *       dst_d = (float *) dst->data;
 
-    GGML_ASSERT(ggml_is_contiguous(src_q));
-    GGML_ASSERT(ggml_is_contiguous(src_k));
-    GGML_ASSERT(ggml_is_contiguous(src_v));
+    GGML_ASSERT(ggml_is_contiguous_rows(src_q));
+    GGML_ASSERT(ggml_is_contiguous_rows(src_k));
+    GGML_ASSERT(ggml_is_contiguous_rows(src_v));
+    GGML_ASSERT(ggml_are_same_stride(src_q, src_k));
+    GGML_ASSERT(ggml_are_same_stride(src_g, src_beta));
     GGML_ASSERT(ggml_is_contiguous(src_g));
     GGML_ASSERT(ggml_is_contiguous(src_beta));
     GGML_ASSERT(ggml_is_contiguous(src_state));
 
+    // strides in floats
+    const int64_t sq1 = nbq1 / sizeof(float);
+    const int64_t sq2 = nbq2 / sizeof(float);
+    const int64_t sq3 = nbq3 / sizeof(float);
+    const int64_t sv1 = nbv1 / sizeof(float);
+    const int64_t sv2 = nbv2 / sizeof(float);
+    const int64_t sv3 = nbv3 / sizeof(float);
+    const int64_t sg1 = nbg1 / sizeof(float);
+    const int64_t sg2 = nbg2 / sizeof(float);
+
     dim3 grid_dims(H, n_seqs, 1);
     dim3 block_dims(S_v, 1, 1);
 
     cudaStream_t stream = ctx.stream();
 
-    switch(S_v) {
+    switch (S_v) {
         case 32:
-            gated_delta_net_cuda<32><<<grid_dims, block_dims, 0, stream>>>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs);
+            gated_delta_net_cuda<32><<<grid_dims, block_dims, 0, stream>>>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
+                                                                           n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2,
+                                                                           sv3, sg1, sg2, rq1, rq3);
            break;
        case 64:
-            gated_delta_net_cuda<64><<<grid_dims, block_dims, 0, stream>>>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs);
+            gated_delta_net_cuda<64><<<grid_dims, block_dims, 0, stream>>>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
+                                                                           n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2,
+                                                                           sv3, sg1, sg2, rq1, rq3);
            break;
        case 128:
-            gated_delta_net_cuda<128><<<grid_dims, block_dims, 0, stream>>>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs);
+            gated_delta_net_cuda<128><<<grid_dims, block_dims, 0, stream>>>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
+                                                                            n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2,
+                                                                            sv3, sg1, sg2, rq1, rq3);
            break;
        default:
            GGML_ABORT("fatal error");
    }
 }
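
Note on the new indexing scheme: the kernel now takes per-dimension strides in
units of floats (sq*/sv*/sg*, derived from ggml's byte strides nb1..nb3) plus
two broadcast ratios, rq1 = nev1/neq1 and rq3 = nev3/neq3, so a q/k tensor with
fewer heads or sequences than v can be read without first being made
contiguous: v head h_idx maps to q/k head iq1 = h_idx / rq1, and likewise for
sequences via rq3. The host-side sketch below replays that arithmetic in
isolation; it is a minimal illustration only, and tensor_meta, q_row_offset,
and the example shapes are hypothetical stand-ins, not part of this patch or
of the ggml API.

#include <cstdint>
#include <cstdio>

// ggml-style tensor metadata (illustrative): ne[i] = extent of dim i,
// nb[i] = stride of dim i in bytes.
struct tensor_meta {
    int64_t ne[4];
    int64_t nb[4];
};

// Offset in floats of the q row used for v-head h_idx, token t, sequence seq,
// mirroring the kernel's "q + iq3 * sq3 + t * sq2 + iq1 * sq1".
static int64_t q_row_offset(const tensor_meta & q, const tensor_meta & v,
                            int64_t h_idx, int64_t t, int64_t seq) {
    const int64_t rq1 = v.ne[1] / q.ne[1];  // v heads per q/k head
    const int64_t rq3 = v.ne[3] / q.ne[3];  // v sequences per q/k sequence
    const int64_t iq1 = h_idx / rq1;        // map v head     -> q/k head
    const int64_t iq3 = seq / rq3;          // map v sequence -> q/k sequence
    const int64_t sq1 = q.nb[1] / (int64_t) sizeof(float);  // byte strides to float
    const int64_t sq2 = q.nb[2] / (int64_t) sizeof(float);  // strides; valid because
    const int64_t sq3 = q.nb[3] / (int64_t) sizeof(float);  // rows are contiguous f32
    return iq3 * sq3 + t * sq2 + iq1 * sq1;
}

int main() {
    // q: 128-wide rows, 4 heads, 7 tokens, 2 sequences; the per-head row stride
    // is padded to 160 floats -- rows contiguous, tensor as a whole not.
    const tensor_meta q = { { 128, 4, 7, 2 },
                            { 4, 4 * 160, 4 * 160 * 4, 4 * 160 * 4 * 7 } };
    // v: fully contiguous, 8 heads, so two v heads share each q/k head.
    const tensor_meta v = { { 128, 8, 7, 2 },
                            { 4, 4 * 128, 4 * 128 * 8, 4 * 128 * 8 * 7 } };
    std::printf("q row offset: %lld floats\n",
                (long long) q_row_offset(q, v, /*h_idx=*/5, /*t=*/3, /*seq=*/1));
    return 0;
}

Converting the byte strides to float strides on the host works because every
operand is f32 and the relaxed ggml_is_contiguous_rows assertions still
guarantee unit stride within a row, which is what lets the kernel index q_t,
k_t, and v_t element-wise.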