cuda: add support for non-contig q,k,v

Aman Gupta 2026-02-13 14:12:09 +01:00
parent 15d83e0c87
commit 2f0ac21d4b
1 changed files with 79 additions and 43 deletions


@@ -1,31 +1,42 @@
-#include "ggml-cuda/common.cuh"
 #include "gated_delta_net.cuh"
+#include "ggml-cuda/common.cuh"
-template<int S_v>
-__global__ void gated_delta_net_cuda(
-    const float * q,
-    const float * k,
-    const float * v,
-    const float * g,
-    const float * beta,
-    const float * curr_state,
-    float * dst,
-    int64_t H,
-    int64_t n_tokens,
-    int64_t n_seqs
-) {
+template <int S_v>
+__global__ void gated_delta_net_cuda(const float * q,
+                                     const float * k,
+                                     const float * v,
+                                     const float * g,
+                                     const float * beta,
+                                     const float * curr_state,
+                                     float * dst,
+                                     int64_t H,
+                                     int64_t n_tokens,
+                                     int64_t n_seqs,
+                                     int64_t sq1,
+                                     int64_t sq2,
+                                     int64_t sq3,
+                                     int64_t sv1,
+                                     int64_t sv2,
+                                     int64_t sv3,
+                                     int64_t sg1,
+                                     int64_t sg2,
+                                     int64_t rq1,
+                                     int64_t rq3) {
     const int64_t h_idx = blockIdx.x;
     const int64_t sequence = blockIdx.y;
     const int col = threadIdx.x;  // each thread owns one column
+    const int64_t iq1 = h_idx / rq1;
+    const int64_t iq3 = sequence / rq3;
     const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
     float * attn_data = dst;
     float * state = dst + attn_score_elems;
     const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v;
     state += state_offset;
     curr_state += state_offset;
     attn_data += (sequence * n_tokens * H + h_idx) * S_v;
     // Copy input state to output state (working area)
 #pragma unroll
@@ -34,14 +45,15 @@ __global__ void gated_delta_net_cuda(
     }
     for (int t = 0; t < n_tokens; t++) {
-        const int64_t qkv_offset = sequence * n_tokens * H * S_v + t * H * S_v + h_idx * S_v;
-        const float * q_t = q + qkv_offset;
-        const float * k_t = k + qkv_offset;
-        const float * v_t = v + qkv_offset;
+        const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1;
+        const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1;
+        const float * v_t = v + sequence * sv3 + t * sv2 + h_idx * sv1;
-        const int64_t gb_offset = sequence * n_tokens * H + t * H + h_idx;
-        const float beta_val = 1.0f / (1.0f + expf(-beta[gb_offset]));
-        const float g_val = expf(g[gb_offset]);
+        const float * g_t = g + sequence * sg2 + t * sg1;
+        const float * beta_t = beta + sequence * sg2 + t * sg1;
+        const float beta_val = 1.0f / (1.0f + expf(-beta_t[h_idx]));
+        const float g_val = expf(g_t[h_idx]);
         // kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i]
         float kv_col = 0.0f;
@@ -70,9 +82,7 @@
     }
 }
-void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx,
-                                  ggml_tensor * dst) {
+void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_tensor * src_q = dst->src[0];
     ggml_tensor * src_k = dst->src[1];
     ggml_tensor * src_v = dst->src[2];
@@ -80,42 +90,68 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx,
     ggml_tensor * src_beta = dst->src[4];
     ggml_tensor * src_state = dst->src[5];
-    const int64_t S_v = src_q->ne[0];
-    const int64_t H = src_q->ne[1];
-    const int64_t n_tokens = src_q->ne[2];
-    const int64_t n_seqs = src_q->ne[3];
+    GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);
+    GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb);
+    GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne);
+    GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb);
+    GGML_TENSOR_LOCALS(size_t, nbg, src_g, nb);
+    const int64_t S_v = nev0;
+    const int64_t H = nev1;
+    const int64_t n_tokens = nev2;
+    const int64_t n_seqs = nev3;
+    const int64_t rq1 = nev1 / neq1;
+    const int64_t rq3 = nev3 / neq3;
     const float * q_d = (const float *) src_q->data;
     const float * k_d = (const float *) src_k->data;
     const float * v_d = (const float *) src_v->data;
     const float * g_d = (const float *) src_g->data;
     const float * b_d = (const float *) src_beta->data;
     const float * s_d = (const float *) src_state->data;
     float * dst_d = (float *) dst->data;
-    GGML_ASSERT(ggml_is_contiguous(src_q));
-    GGML_ASSERT(ggml_is_contiguous(src_k));
-    GGML_ASSERT(ggml_is_contiguous(src_v));
+    GGML_ASSERT(ggml_is_contiguous_rows(src_q));
+    GGML_ASSERT(ggml_is_contiguous_rows(src_k));
+    GGML_ASSERT(ggml_is_contiguous_rows(src_v));
+    GGML_ASSERT(ggml_are_same_stride(src_q, src_k));
+    GGML_ASSERT(ggml_are_same_stride(src_g, src_beta));
     GGML_ASSERT(ggml_is_contiguous(src_g));
     GGML_ASSERT(ggml_is_contiguous(src_beta));
     GGML_ASSERT(ggml_is_contiguous(src_state));
+    // strides in floats
+    const int64_t sq1 = nbq1 / sizeof(float);
+    const int64_t sq2 = nbq2 / sizeof(float);
+    const int64_t sq3 = nbq3 / sizeof(float);
+    const int64_t sv1 = nbv1 / sizeof(float);
+    const int64_t sv2 = nbv2 / sizeof(float);
+    const int64_t sv3 = nbv3 / sizeof(float);
+    const int64_t sg1 = nbg1 / sizeof(float);
+    const int64_t sg2 = nbg2 / sizeof(float);
     dim3 grid_dims(H, n_seqs, 1);
     dim3 block_dims(S_v, 1, 1);
     cudaStream_t stream = ctx.stream();
-    switch(S_v) {
+    switch (S_v) {
         case 32:
-            gated_delta_net_cuda<32><<<grid_dims, block_dims, 0, stream>>>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs);
+            gated_delta_net_cuda<32><<<grid_dims, block_dims, 0, stream>>>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
+                                                                           n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2,
+                                                                           sv3, sg1, sg2, rq1, rq3);
             break;
         case 64:
-            gated_delta_net_cuda<64><<<grid_dims, block_dims, 0, stream>>>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs);
+            gated_delta_net_cuda<64><<<grid_dims, block_dims, 0, stream>>>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
+                                                                           n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2,
+                                                                           sv3, sg1, sg2, rq1, rq3);
             break;
         case 128:
-            gated_delta_net_cuda<128><<<grid_dims, block_dims, 0, stream>>>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs);
+            gated_delta_net_cuda<128><<<grid_dims, block_dims, 0, stream>>>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
+                                                                            n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2,
+                                                                            sv3, sg1, sg2, rq1, rq3);
             break;
         default:
             GGML_ABORT("fatal error");