From 4114537fb9f19850775437d82eb092507d830683 Mon Sep 17 00:00:00 2001 From: hauhaut Date: Tue, 16 Dec 2025 02:08:41 +0100 Subject: [PATCH 1/3] ggml-cuda: Delta-Net linear attention for Qwen3-Next --- ggml/include/ggml.h | 10 + ggml/src/ggml-cuda/delta-net.cu | 1611 ++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/delta-net.cuh | 3 + ggml/src/ggml-cuda/ggml-cuda.cu | 13 + ggml/src/ggml-cuda/solve_tri.cu | 867 +++++++++++++--- ggml/src/ggml.c | 63 +- src/models/models.h | 4 +- src/models/qwen3next.cpp | 82 +- 8 files changed, 2509 insertions(+), 144 deletions(-) create mode 100644 ggml/src/ggml-cuda/delta-net.cu create mode 100644 ggml/src/ggml-cuda/delta-net.cuh diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 20c912d0e9..609c880ff1 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -551,6 +551,7 @@ extern "C" { GGML_OP_GATED_LINEAR_ATTN, GGML_OP_RWKV_WKV7, GGML_OP_SOLVE_TRI, + GGML_OP_DELTA_NET, GGML_OP_UNARY, @@ -2460,6 +2461,15 @@ extern "C" { bool lower, bool uni); + GGML_API struct ggml_tensor * ggml_delta_net( + struct ggml_context * ctx, + struct ggml_tensor * q, // [S_k, n_tokens, H_k, n_seqs] - Query (pre-permuted) + struct ggml_tensor * k, // [S_k, n_tokens, H_k, n_seqs] - Key (pre-permuted) + struct ggml_tensor * v, // [S_v, n_tokens, H_v, n_seqs] - Value (pre-permuted) + struct ggml_tensor * g, // [n_tokens, 1, H_k, n_seqs] - Gate logits (pre-permuted) + struct ggml_tensor * beta, // [1, n_tokens, H_k, n_seqs] - Beta (pre-permuted) + struct ggml_tensor * state); // [S_v, S_v*H_v, 1, n_seqs] - Recurrent state + // custom operators typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata); diff --git a/ggml/src/ggml-cuda/delta-net.cu b/ggml/src/ggml-cuda/delta-net.cu new file mode 100644 index 0000000000..634780d5dc --- /dev/null +++ b/ggml/src/ggml-cuda/delta-net.cu @@ -0,0 +1,1611 @@ +#include "common.cuh" +#include "delta-net.cuh" +#include + +// Delta Net Linear Attention Kernel for Qwen3-Next (HEAD_DIM=128) +// State layout: [S_v, S_v*H_v, 1, n_seqs] (column-major) + +__device__ __forceinline__ float sigmoid_f(float x) { + return 1.0f / (1.0f + expf(-x)); +} + +// Token-by-token recurrent kernel +// One block per (batch, head) pair, processes all tokens sequentially +// State is kept in global memory (too large for shared memory at HEAD_DIM=128) +template +__global__ void delta_net_recurrent_f32( + const float * __restrict__ q, // [HEAD_DIM, n_tokens, n_heads, n_seqs] + const float * __restrict__ k, // [HEAD_DIM, n_tokens, n_heads, n_seqs] + const float * __restrict__ v, // [HEAD_DIM, n_tokens, n_heads, n_seqs] + const float * __restrict__ g, // [n_tokens, 1, n_heads, n_seqs] + const float * __restrict__ beta_in, // [1, n_tokens, n_heads, n_seqs] + const float * __restrict__ state_in, // [HEAD_DIM, HEAD_DIM*n_heads, 1, n_seqs] + float * __restrict__ dst, // output + new_state concatenated + const int64_t n_tokens, + const int64_t n_heads, + const int64_t n_seqs, + const int64_t output_offset, // offset where state starts in output + const float eps) +{ + const int batch_idx = blockIdx.x / n_heads; + const int head_idx = blockIdx.x % n_heads; + const int tid = threadIdx.x; + const int warp_id = tid / WARP_SIZE; // 0-7 for 256 threads + const int lane_id = tid % WARP_SIZE; // 0-31 + constexpr int NUM_WARPS = 8; // 256 / 32 + + // Strides for input tensors (column-major) + // Q/K/V: [HEAD_DIM, n_tokens, n_heads, n_seqs] + const int64_t qkv_stride_token = HEAD_DIM; + const int64_t 
qkv_stride_head = HEAD_DIM * n_tokens; + const int64_t qkv_stride_batch = HEAD_DIM * n_tokens * n_heads; + + // G/Beta: [n_tokens, 1, n_heads, n_seqs] / [1, n_tokens, n_heads, n_seqs] + const int64_t g_stride_head = n_tokens; + const int64_t g_stride_batch = n_tokens * n_heads; + + // State: [HEAD_DIM, HEAD_DIM*n_heads, 1, n_seqs] + // For head h: columns h*HEAD_DIM to (h+1)*HEAD_DIM + // state[row, col] for head h = state[row, h*HEAD_DIM + col] + // Linear index: row + (h*HEAD_DIM + col) * HEAD_DIM = row + h*HEAD_DIM^2 + col*HEAD_DIM + const int64_t state_head_offset = head_idx * HEAD_DIM * HEAD_DIM; + const int64_t state_batch_stride = HEAD_DIM * HEAD_DIM * n_heads; + + // Pointers for this batch/head + const float * q_ptr = q + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; + const float * k_ptr = k + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; + const float * v_ptr = v + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; + const float * g_ptr = g + batch_idx * g_stride_batch + head_idx * g_stride_head; + const float * beta_ptr = beta_in + batch_idx * g_stride_batch + head_idx * g_stride_head; + const float * state_src = state_in + batch_idx * state_batch_stride + state_head_offset; + + // Output layout: [head_v_dim, num_v_heads, n_seq_tokens, n_seqs] + // For [dim, head, token, batch]: index = dim + head*S_v + token*S_v*H_v + batch*S_v*H_v*n_tokens + float * out_base = dst + batch_idx * (HEAD_DIM * n_heads * n_tokens) + head_idx * HEAD_DIM; + const int64_t out_token_stride = HEAD_DIM * n_heads; // stride between tokens + float * state_dst = dst + output_offset + batch_idx * state_batch_stride + state_head_offset; + + // Shared memory for scalars (moved outside loop for clarity) + __shared__ float shared_g_val, shared_beta_val, shared_decay, shared_attn_score; + + // Shared memory for current token's Q, K, V (normalized), and intermediate results + extern __shared__ float smem[]; + float * sQ = smem; // HEAD_DIM + float * sK = sQ + HEAD_DIM; // HEAD_DIM + float * sV = sK + HEAD_DIM; // HEAD_DIM + float * sKBeta = sV + HEAD_DIM; // HEAD_DIM (plain k for state update) + float * sVBeta = sKBeta + HEAD_DIM; // HEAD_DIM (v * sigmoid(beta)) + float * sOut = sVBeta + HEAD_DIM; // HEAD_DIM + float * sKCumdecay = sOut + HEAD_DIM; // HEAD_DIM (k * sigmoid(beta) * exp(g)) + float * sVPrime = sKCumdecay + HEAD_DIM; // HEAD_DIM (state @ k_cumdecay) + float * sVNew = sVPrime + HEAD_DIM; // HEAD_DIM (v_beta - v_prime) + float * sNorm = sVNew + HEAD_DIM; // 2 (for Q and K norms) + + const float scale = rsqrtf((float)HEAD_DIM); + + // Copy initial state to output buffer (will be updated in place) + for (int i = tid; i < HEAD_DIM * HEAD_DIM; i += blockDim.x) { + int col = i / HEAD_DIM; + int row = i % HEAD_DIM; + // Column-major: state[row, col] at index row + col*HEAD_DIM + state_dst[row + col * HEAD_DIM] = state_src[row + col * HEAD_DIM]; + } + __syncthreads(); + + // Process each token sequentially + for (int64_t t = 0; t < n_tokens; t++) { + // Reset norm accumulators + if (tid < 2) { + sNorm[tid] = 0.0f; + } + __syncthreads(); + + for (int i = tid; i < HEAD_DIM; i += blockDim.x) { + sQ[i] = q_ptr[t * qkv_stride_token + i]; + sK[i] = k_ptr[t * qkv_stride_token + i]; + sV[i] = v_ptr[t * qkv_stride_token + i]; + } + __syncthreads(); + + float q_sq_local = 0.0f; + float k_sq_local = 0.0f; + for (int i = tid; i < HEAD_DIM; i += blockDim.x) { + q_sq_local += sQ[i] * sQ[i]; + k_sq_local += sK[i] * sK[i]; + } + + // Warp reduction + #pragma unroll + for (int offset = 
WARP_SIZE/2; offset > 0; offset /= 2) { + q_sq_local += __shfl_xor_sync(0xffffffff, q_sq_local, offset); + k_sq_local += __shfl_xor_sync(0xffffffff, k_sq_local, offset); + } + + // Cross-warp reduction using shared memory atomics + if (tid % WARP_SIZE == 0) { + atomicAdd(&sNorm[0], q_sq_local); + atomicAdd(&sNorm[1], k_sq_local); + } + __syncthreads(); + + float q_norm = rsqrtf(sNorm[0] + eps); + float k_norm = rsqrtf(sNorm[1] + eps); + + for (int i = tid; i < HEAD_DIM; i += blockDim.x) { + sQ[i] = sQ[i] * q_norm * scale; + sK[i] = sK[i] * k_norm; + } + __syncthreads(); + + if (tid == 0) { + shared_g_val = g_ptr[t]; + shared_beta_val = sigmoid_f(beta_ptr[t]); + shared_decay = expf(fminf(shared_g_val, 50.0f)); + } + __syncthreads(); + + float beta_val = shared_beta_val; + float decay = shared_decay; + + for (int i = tid; i < HEAD_DIM; i += blockDim.x) { + sKBeta[i] = sK[i]; + sVBeta[i] = sV[i] * beta_val; + sKCumdecay[i] = sK[i] * beta_val * decay; + } + __syncthreads(); + + for (int row_out = warp_id; row_out < HEAD_DIM; row_out += NUM_WARPS) { + float sum = 0.0f; + #pragma unroll 4 + for (int col = lane_id; col < HEAD_DIM; col += WARP_SIZE) { + sum += state_dst[row_out + col * HEAD_DIM] * sKCumdecay[col]; + } + #pragma unroll + for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { + sum += __shfl_xor_sync(0xffffffff, sum, offset); + } + if (lane_id == 0) { + sVPrime[row_out] = sum; + } + } + __syncthreads(); + + for (int i = tid; i < HEAD_DIM; i += blockDim.x) { + sVNew[i] = sVBeta[i] - sVPrime[i]; + } + __syncthreads(); + + if (warp_id == 0) { + float sum = 0.0f; + #pragma unroll 4 + for (int i = lane_id; i < HEAD_DIM; i += WARP_SIZE) { + sum += sK[i] * sQ[i]; + } + // Warp reduction + #pragma unroll + for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { + sum += __shfl_xor_sync(0xffffffff, sum, offset); + } + if (lane_id == 0) { + shared_attn_score = sum; + } + } + __syncthreads(); + + for (int row_out = warp_id; row_out < HEAD_DIM; row_out += NUM_WARPS) { + float sum = 0.0f; + #pragma unroll 4 + for (int col = lane_id; col < HEAD_DIM; col += WARP_SIZE) { + float state_val = state_dst[row_out + col * HEAD_DIM]; + sum += sQ[col] * decay * state_val; + } + #pragma unroll + for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { + sum += __shfl_xor_sync(0xffffffff, sum, offset); + } + if (lane_id == 0) { + float v_attn = sVNew[row_out] * shared_attn_score; + sOut[row_out] = sum + v_attn; + } + } + __syncthreads(); + + for (int out_dim = tid; out_dim < HEAD_DIM; out_dim += blockDim.x) { + for (int row = 0; row < HEAD_DIM; row++) { + float state_val = state_dst[row + out_dim * HEAD_DIM]; + float safe_decay = decay; + if (isnan(safe_decay) || isinf(safe_decay)) { + safe_decay = 1.0f; + } + float new_state_val = safe_decay * state_val + sVNew[row] * sKBeta[out_dim]; + new_state_val = fminf(fmaxf(new_state_val, -1e6f), 1e6f); + state_dst[row + out_dim * HEAD_DIM] = new_state_val; + } + } + __syncthreads(); + + for (int i = tid; i < HEAD_DIM; i += blockDim.x) { + out_base[t * out_token_stride + i] = sOut[i]; + } + __syncthreads(); + } +} + +// Generic kernel that handles any HEAD_DIM at runtime (slower but flexible) +__global__ void delta_net_recurrent_generic_f32( + const float * __restrict__ q, + const float * __restrict__ k, + const float * __restrict__ v, + const float * __restrict__ g, + const float * __restrict__ beta_in, + const float * __restrict__ state_in, + float * __restrict__ dst, + const int64_t head_dim, + const int64_t n_tokens, + const int64_t n_heads, + const int64_t 
n_seqs, + const int64_t output_offset, + const float eps) +{ + const int batch_idx = blockIdx.x / n_heads; + const int head_idx = blockIdx.x % n_heads; + const int tid = threadIdx.x; + + // Strides (column-major) + const int64_t qkv_stride_token = head_dim; + const int64_t qkv_stride_head = head_dim * n_tokens; + const int64_t qkv_stride_batch = head_dim * n_tokens * n_heads; + + const int64_t g_stride_head = n_tokens; + const int64_t g_stride_batch = n_tokens * n_heads; + + const int64_t state_head_offset = head_idx * head_dim * head_dim; + const int64_t state_batch_stride = head_dim * head_dim * n_heads; + + // Pointers + const float * q_ptr = q + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; + const float * k_ptr = k + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; + const float * v_ptr = v + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; + const float * g_ptr = g + batch_idx * g_stride_batch + head_idx * g_stride_head; + const float * beta_ptr = beta_in + batch_idx * g_stride_batch + head_idx * g_stride_head; + const float * state_src = state_in + batch_idx * state_batch_stride + state_head_offset; + + // Output layout: [head_v_dim, num_v_heads, n_seq_tokens, n_seqs] + float * out_base = dst + batch_idx * (head_dim * n_heads * n_tokens) + head_idx * head_dim; + const int64_t out_token_stride = head_dim * n_heads; + float * state_dst = dst + output_offset + batch_idx * state_batch_stride + state_head_offset; + + // Shared memory for scalars (outside loop) + __shared__ float shared_g_val, shared_beta_val, shared_decay, shared_attn_score; + + // Dynamic shared memory + extern __shared__ float smem[]; + float * sQ = smem; + float * sK = sQ + head_dim; + float * sV = sK + head_dim; + float * sKBeta = sV + head_dim; // plain k for state update + float * sVBeta = sKBeta + head_dim; // v * sigmoid(beta) + float * sOut = sVBeta + head_dim; + float * sKCumdecay = sOut + head_dim; // k * sigmoid(beta) * exp(g) + float * sVPrime = sKCumdecay + head_dim; // state @ k_cumdecay + float * sVNew = sVPrime + head_dim; // v_beta - v_prime + float * sNorm = sVNew + head_dim; + + const float scale = rsqrtf((float)head_dim); + + // Copy initial state to output buffer + for (int i = tid; i < head_dim * head_dim; i += blockDim.x) { + int col = i / head_dim; + int row = i % head_dim; + state_dst[row + col * head_dim] = state_src[row + col * head_dim]; + } + __syncthreads(); + + // Process each token + for (int64_t t = 0; t < n_tokens; t++) { + if (tid < 2) sNorm[tid] = 0.0f; + __syncthreads(); + + // Load Q, K, V + for (int i = tid; i < head_dim; i += blockDim.x) { + sQ[i] = q_ptr[t * qkv_stride_token + i]; + sK[i] = k_ptr[t * qkv_stride_token + i]; + sV[i] = v_ptr[t * qkv_stride_token + i]; + } + __syncthreads(); + + // L2 normalize Q and K + float q_sq = 0.0f, k_sq = 0.0f; + for (int i = tid; i < head_dim; i += blockDim.x) { + q_sq += sQ[i] * sQ[i]; + k_sq += sK[i] * sK[i]; + } + + #pragma unroll + for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { + q_sq += __shfl_xor_sync(0xffffffff, q_sq, offset); + k_sq += __shfl_xor_sync(0xffffffff, k_sq, offset); + } + + if (tid % WARP_SIZE == 0) { + atomicAdd(&sNorm[0], q_sq); + atomicAdd(&sNorm[1], k_sq); + } + __syncthreads(); + + float q_norm = rsqrtf(sNorm[0] + eps); + float k_norm = rsqrtf(sNorm[1] + eps); + + for (int i = tid; i < head_dim; i += blockDim.x) { + sQ[i] *= q_norm * scale; + sK[i] *= k_norm; + } + __syncthreads(); + + // Load g and beta, compute decay + if (tid == 0) { + shared_g_val = g_ptr[t]; + 
shared_beta_val = sigmoid_f(beta_ptr[t]); + shared_decay = expf(fminf(shared_g_val, 50.0f)); + } + __syncthreads(); + + float beta_val = shared_beta_val; + float decay = shared_decay; + + // Compute k_beta, v_beta, k_cumdecay + for (int i = tid; i < head_dim; i += blockDim.x) { + sKBeta[i] = sK[i]; + sVBeta[i] = sV[i] * beta_val; + sKCumdecay[i] = sK[i] * beta_val * decay; + } + __syncthreads(); + + // Compute v_prime = state @ k_cumdecay + for (int row_out = tid; row_out < head_dim; row_out += blockDim.x) { + float v_prime_val = 0.0f; + for (int col = 0; col < head_dim; col++) { + // Access state[row_out, col] = state_dst[row_out + col * head_dim] for state @ k + v_prime_val += state_dst[row_out + col * head_dim] * sKCumdecay[col]; + } + sVPrime[row_out] = v_prime_val; + } + __syncthreads(); + + // Compute v_new = v_beta - v_prime (the value residual) + for (int i = tid; i < head_dim; i += blockDim.x) { + sVNew[i] = sVBeta[i] - sVPrime[i]; + } + __syncthreads(); + + // Compute attn_score = dot(k, q) (L2 normalized vectors) + if (tid == 0) { + float dot_sum = 0.0f; + for (int i = 0; i < head_dim; i++) { + dot_sum += sK[i] * sQ[i]; + } + shared_attn_score = dot_sum; + } + __syncthreads(); + + // Compute output: o[t] = attn_inter + v_attn + // attn_inter = state @ (q * exp(g)) = sum_col(state[row_out, col] * q[col] * exp(g)) + // The decomposed path uses: attn_inter = ggml_mul_mat(state_t, q_g_exp) + // Since ggml_mul_mat(A,B) = A^T @ B, attn_inter = state_t^T @ q_g_exp = state @ (q * exp(g)) + for (int row_out = tid; row_out < head_dim; row_out += blockDim.x) { + float attn_inter = 0.0f; + + for (int col = 0; col < head_dim; col++) { + // Access state[row_out, col] = state_dst[row_out + col * head_dim] for state @ q + float state_val = state_dst[row_out + col * head_dim]; + attn_inter += sQ[col] * decay * state_val; + } + + // v_attn = v_new * attn_score + float v_attn = sVNew[row_out] * shared_attn_score; + + // Output = attn_inter + v_attn (correct DeltaNet formula) + sOut[row_out] = attn_inter + v_attn; + } + __syncthreads(); + + // Update state: state_new = decay * state + outer(v_new, k) + // Fixed: outer product orientation matches decomposed: state[v_idx, k_idx] += v_new[v_idx] * k[k_idx] + // Uses transposed indexing: state_dst[row + out_dim * head_dim] = state[row][out_dim] + // Only protect against NaN/Inf - do NOT clamp decay value + float safe_decay = decay; + if (isnan(safe_decay) || isinf(safe_decay)) { + safe_decay = 1.0f; + } + + for (int out_dim = tid; out_dim < head_dim; out_dim += blockDim.x) { + for (int row = 0; row < head_dim; row++) { + float state_val = state_dst[row + out_dim * head_dim]; + + // state_new[row][out_dim] = decay * state[row][out_dim] + v_new[row] * k[out_dim] + // Fix: outer product matches decomposed path: state[v_idx, k_idx] += v_new[v_idx] * k[k_idx] + float new_state_val = safe_decay * state_val + sVNew[row] * sKBeta[out_dim]; + + // Clamp state to prevent overflow + new_state_val = fminf(fmaxf(new_state_val, -1e6f), 1e6f); + state_dst[row + out_dim * head_dim] = new_state_val; + } + } + __syncthreads(); + + // Write output + for (int i = tid; i < head_dim; i += blockDim.x) { + out_base[t * out_token_stride + i] = sOut[i]; + } + __syncthreads(); + } +} + +// FP16 DeltaNet kernel using __hfma2 for 2x throughput +#if !defined(GGML_USE_HIP) +template +__global__ void delta_net_fp16_optimized( + const float * __restrict__ q, + const float * __restrict__ k, + const float * __restrict__ v, + const float * __restrict__ g, + const float * __restrict__ 
beta_in, + const float * __restrict__ state_in, + float * __restrict__ dst, + const int64_t n_tokens, + const int64_t n_heads, + const int64_t n_seqs, + const int64_t output_offset, + const float eps) +{ + static_assert(HEAD_DIM == 128, "FP16 kernel requires HEAD_DIM=128"); + static_assert(HEAD_DIM % 2 == 0, "HEAD_DIM must be even for half2"); + + const int batch_idx = blockIdx.x / n_heads; + const int head_idx = blockIdx.x % n_heads; + const int tid = threadIdx.x; + const int warp_id = tid / WARP_SIZE; + const int lane_id = tid % WARP_SIZE; + constexpr int NUM_WARPS = 8; // 256 threads / 32 + + // Strides (column-major) + const int64_t qkv_stride_token = HEAD_DIM; + const int64_t qkv_stride_head = HEAD_DIM * n_tokens; + const int64_t qkv_stride_batch = HEAD_DIM * n_tokens * n_heads; + const int64_t g_stride_head = n_tokens; + const int64_t g_stride_batch = n_tokens * n_heads; + const int64_t state_head_offset = head_idx * HEAD_DIM * HEAD_DIM; + const int64_t state_batch_stride = HEAD_DIM * HEAD_DIM * n_heads; + + // Pointers + const float * q_ptr = q + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; + const float * k_ptr = k + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; + const float * v_ptr = v + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; + const float * g_ptr = g + batch_idx * g_stride_batch + head_idx * g_stride_head; + const float * beta_ptr = beta_in + batch_idx * g_stride_batch + head_idx * g_stride_head; + const float * state_src = state_in + batch_idx * state_batch_stride + state_head_offset; + float * out_base = dst + batch_idx * (HEAD_DIM * n_heads * n_tokens) + head_idx * HEAD_DIM; + const int64_t out_token_stride = HEAD_DIM * n_heads; + float * state_dst = dst + output_offset + batch_idx * state_batch_stride + state_head_offset; + + // Shared memory layout: + // - FP16 state COLUMN-MAJOR: 128×128 = 16384 half = 32KB + // - FP16 vectors: K, KCumdecay, Q_scaled = 3 × 128 = 384 half = 768 bytes + // - FP32 vectors: V, KBeta, VBeta, Out, VPrime, VNew = 6 × 128 = 768 floats = 3KB + // Total: ~36KB + + extern __shared__ char smem_raw[]; + + // FP16 state COLUMN-MAJOR: state[row, col] = state_smem[row + col * HEAD_DIM] + half * state_smem = (half *)smem_raw; + + // FP16 vectors + half * sK_fp16 = (half *)(smem_raw + HEAD_DIM * HEAD_DIM * sizeof(half)); + half * sKCumdecay_fp16 = sK_fp16 + HEAD_DIM; + half * sQ_fp16 = sKCumdecay_fp16 + HEAD_DIM; + + // FP32 vectors + float * sV = (float *)(sQ_fp16 + HEAD_DIM); + float * sKBeta = sV + HEAD_DIM; + float * sVBeta = sKBeta + HEAD_DIM; + float * sOut = sVBeta + HEAD_DIM; + float * sVPrime = sOut + HEAD_DIM; + float * sVNew = sVPrime + HEAD_DIM; + float * sNorm = sVNew + HEAD_DIM; + + __shared__ float shared_decay, shared_attn_score; + + const float scale = rsqrtf((float)HEAD_DIM); + + // Load initial state DIRECTLY (no transpose - same layout as global) + // state[row, col] = state_smem[row + col * HEAD_DIM] + for (int i = tid; i < HEAD_DIM * HEAD_DIM; i += blockDim.x) { + state_smem[i] = __float2half(state_src[i]); + } + __syncthreads(); + + // Process each token + for (int64_t t = 0; t < n_tokens; t++) { + // Reset norms + if (tid < 2) { + sNorm[tid] = 0.0f; + } + __syncthreads(); + + // 1. 
Load Q, K, V and compute norms + float q_sq_local = 0.0f, k_sq_local = 0.0f; + for (int i = tid; i < HEAD_DIM; i += blockDim.x) { + float q_val = q_ptr[t * qkv_stride_token + i]; + float k_val = k_ptr[t * qkv_stride_token + i]; + sV[i] = v_ptr[t * qkv_stride_token + i]; + q_sq_local += q_val * q_val; + k_sq_local += k_val * k_val; + sVPrime[i] = q_val; // Temp storage for Q + sVNew[i] = k_val; // Temp storage for K + } + + // Warp reduction for norms + #pragma unroll + for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { + q_sq_local += __shfl_xor_sync(0xffffffff, q_sq_local, offset); + k_sq_local += __shfl_xor_sync(0xffffffff, k_sq_local, offset); + } + if (lane_id == 0) { + atomicAdd(&sNorm[0], q_sq_local); + atomicAdd(&sNorm[1], k_sq_local); + } + __syncthreads(); + + float q_norm = rsqrtf(sNorm[0] + eps); + float k_norm = rsqrtf(sNorm[1] + eps); + + // 2. Load g and beta, compute decay + if (tid == 0) { + shared_decay = expf(fminf(g_ptr[t], 50.0f)); // Clamp g to prevent overflow + } + __syncthreads(); + float decay = shared_decay; + float beta_val = sigmoid_f(beta_ptr[t]); + + // 3. Compute normalized vectors and convert to FP16 + for (int i = tid; i < HEAD_DIM; i += blockDim.x) { + float q_normalized = sVPrime[i] * q_norm * scale; + float k_normalized = sVNew[i] * k_norm; + + sQ_fp16[i] = __float2half(q_normalized * decay); + sK_fp16[i] = __float2half(k_normalized); + sKCumdecay_fp16[i] = __float2half(k_normalized * beta_val * decay); + + sKBeta[i] = k_normalized; + sVBeta[i] = sV[i] * beta_val; + } + __syncthreads(); + + // 4. v_prime = state @ k_cumdecay using half2 + // Column-major: state[row, col] = state_smem[row + col * HEAD_DIM] + // v_prime[col] = sum_row(state[row, col] * k_cumdecay[row]) + // For fixed col, state[0,col], state[1,col], ... = state_smem[col*128], state_smem[col*128+1], ... + // These ARE contiguous! Can use half2. + for (int col = warp_id; col < HEAD_DIM; col += NUM_WARPS) { + half2 sum_h2 = __float2half2_rn(0.0f); + half2 * state_col = (half2 *)(&state_smem[col * HEAD_DIM]); + half2 * vec_h2 = (half2 *)sKCumdecay_fp16; + + #pragma unroll 2 + for (int row = lane_id; row < HEAD_DIM / 2; row += WARP_SIZE) { + sum_h2 = __hfma2(state_col[row], vec_h2[row], sum_h2); + } + + float sum = __half2float(sum_h2.x) + __half2float(sum_h2.y); + #pragma unroll + for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { + sum += __shfl_xor_sync(0xffffffff, sum, offset); + } + + if (lane_id == 0) { + sVPrime[col] = sum; + } + } + __syncthreads(); + + // 5. v_new = v_beta - v_prime + for (int i = tid; i < HEAD_DIM; i += blockDim.x) { + sVNew[i] = sVBeta[i] - sVPrime[i]; + } + __syncthreads(); + + // 6. attn_score = dot(k, q) in FP32 + if (warp_id == 0) { + float sum = 0.0f; + for (int i = lane_id; i < HEAD_DIM; i += WARP_SIZE) { + sum += sKBeta[i] * __half2float(sQ_fp16[i]) / decay; + } + #pragma unroll + for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { + sum += __shfl_xor_sync(0xffffffff, sum, offset); + } + if (lane_id == 0) { + shared_attn_score = sum; + } + } + __syncthreads(); + + // 7. output = attn_inter + v_attn + // attn_inter[col] = sum_row(state[row, col] * q_scaled[row]) + // Same pattern as v_prime - columns are contiguous! 
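+        // Note: sQ_fp16 was pre-scaled by decay in step 3, so this single half2 sweep
+        // yields attn_inter[col] = sum_row(state[row, col] * q[row] * exp(g)) directly;
+        // step 6 divides by decay again to recover the plain dot(k, q) attention score.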
+ for (int col = warp_id; col < HEAD_DIM; col += NUM_WARPS) { + half2 sum_h2 = __float2half2_rn(0.0f); + half2 * state_col = (half2 *)(&state_smem[col * HEAD_DIM]); + half2 * vec_h2 = (half2 *)sQ_fp16; + + #pragma unroll 2 + for (int row = lane_id; row < HEAD_DIM / 2; row += WARP_SIZE) { + sum_h2 = __hfma2(state_col[row], vec_h2[row], sum_h2); + } + + float sum = __half2float(sum_h2.x) + __half2float(sum_h2.y); + #pragma unroll + for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { + sum += __shfl_xor_sync(0xffffffff, sum, offset); + } + + if (lane_id == 0) { + float v_attn = sVNew[col] * shared_attn_score; + sOut[col] = sum + v_attn; + } + } + __syncthreads(); + + // 8. Update state: state_new = decay * state + outer(k, v_new) + // state[row, col] = decay * state[row, col] + k[row] * v_new[col] + half decay_h = __float2half(fminf(fmaxf(decay, 0.0f), 10.0f)); + + for (int i = tid; i < HEAD_DIM * HEAD_DIM; i += blockDim.x) { + int col = i / HEAD_DIM; + int row = i % HEAD_DIM; + + half state_val = state_smem[row + col * HEAD_DIM]; + half k_val = sK_fp16[row]; + half v_new_h = __float2half(sVNew[col]); + + half new_val = __hfma(decay_h, state_val, __hmul(k_val, v_new_h)); + + float new_val_f = __half2float(new_val); + new_val_f = fminf(fmaxf(new_val_f, -1e4f), 1e4f); + state_smem[row + col * HEAD_DIM] = __float2half(new_val_f); + } + __syncthreads(); + + // 9. Write output + for (int i = tid; i < HEAD_DIM; i += blockDim.x) { + out_base[t * out_token_stride + i] = sOut[i]; + } + __syncthreads(); + } + + // Write final state DIRECTLY (no transpose needed - same layout) + for (int i = tid; i < HEAD_DIM * HEAD_DIM; i += blockDim.x) { + state_dst[i] = __half2float(state_smem[i]); + } +} + +#endif // !defined(GGML_USE_HIP) + +// Blackwell kernel (SM 12.0+): Full 64KB state in shared memory +#if !defined(GGML_USE_HIP) + +template +__global__ __launch_bounds__(256, 1) // 256 threads, 1 block per SM for max shared mem +void delta_net_blackwell_f32( + const float * __restrict__ q, + const float * __restrict__ k, + const float * __restrict__ v, + const float * __restrict__ g, + const float * __restrict__ beta_in, + const float * __restrict__ state_in, + float * __restrict__ dst, + const int64_t n_tokens, + const int64_t n_heads, + const int64_t n_seqs, + const int64_t output_offset, + const float eps) +{ + static_assert(HEAD_DIM == 128, "Blackwell kernel optimized for HEAD_DIM=128"); + + // One block per (batch, head) - NO column splitting! 
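+    // The full 128x128 FP32 state (64KB) lives in dynamic shared memory, which is why
+    // __launch_bounds__(256, 1) caps residency at one block per SM and the dispatcher
+    // raises cudaFuncAttributeMaxDynamicSharedMemorySize before launching this kernel.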
+ const int batch_idx = blockIdx.x / n_heads; + const int head_idx = blockIdx.x % n_heads; + const int tid = threadIdx.x; + const int warp_id = tid / WARP_SIZE; + const int lane_id = tid % WARP_SIZE; + constexpr int NUM_WARPS = 8; // 256 / 32 + + // Strides (column-major) + const int64_t qkv_stride_token = HEAD_DIM; + const int64_t qkv_stride_head = HEAD_DIM * n_tokens; + const int64_t qkv_stride_batch = HEAD_DIM * n_tokens * n_heads; + const int64_t g_stride_head = n_tokens; + const int64_t g_stride_batch = n_tokens * n_heads; + const int64_t state_head_offset = head_idx * HEAD_DIM * HEAD_DIM; + const int64_t state_batch_stride = HEAD_DIM * HEAD_DIM * n_heads; + + // Pointers + const float * q_ptr = q + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; + const float * k_ptr = k + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; + const float * v_ptr = v + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; + const float * g_ptr = g + batch_idx * g_stride_batch + head_idx * g_stride_head; + const float * beta_ptr = beta_in + batch_idx * g_stride_batch + head_idx * g_stride_head; + const float * state_src = state_in + batch_idx * state_batch_stride + state_head_offset; + float * out_base = dst + batch_idx * (HEAD_DIM * n_heads * n_tokens) + head_idx * HEAD_DIM; + const int64_t out_token_stride = HEAD_DIM * n_heads; + float * state_dst = dst + output_offset + batch_idx * state_batch_stride + state_head_offset; + + // Shared memory: 64KB state + 4.5KB vectors + scratch + extern __shared__ char smem_raw[]; + float * state_smem = (float *)smem_raw; + float * sQ = (float *)(smem_raw + HEAD_DIM * HEAD_DIM * sizeof(float)); + float * sK = sQ + HEAD_DIM; + float * sV = sK + HEAD_DIM; + float * sKBeta = sV + HEAD_DIM; + float * sVBeta = sKBeta + HEAD_DIM; + float * sKCumdecay = sVBeta + HEAD_DIM; + float * sVPrime = sKCumdecay + HEAD_DIM; + float * sVNew = sVPrime + HEAD_DIM; + float * sOut = sVNew + HEAD_DIM; + + float * warp_scratch = sOut + HEAD_DIM; + __shared__ float shared_decay, shared_beta, shared_attn_score, shared_q_norm, shared_k_norm; + const float scale = rsqrtf((float)HEAD_DIM); + + // Load state (transposed for coalesced access) + #pragma unroll 8 + for (int i = tid; i < HEAD_DIM * HEAD_DIM; i += blockDim.x) { + int col = i / HEAD_DIM, row = i % HEAD_DIM; + state_smem[row + col * HEAD_DIM] = state_src[col + row * HEAD_DIM]; + } + __syncthreads(); + + for (int64_t t = 0; t < n_tokens; t++) { + // Load Q, K, V and compute norms + float q_sq_local = 0.0f, k_sq_local = 0.0f; + #pragma unroll 2 + for (int i = tid; i < HEAD_DIM; i += blockDim.x) { + float q_val = q_ptr[t * qkv_stride_token + i]; + float k_val = k_ptr[t * qkv_stride_token + i]; + sQ[i] = q_val; + sK[i] = k_val; + sV[i] = v_ptr[t * qkv_stride_token + i]; + q_sq_local += q_val * q_val; + k_sq_local += k_val * k_val; + } + + #pragma unroll + for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { + q_sq_local += __shfl_xor_sync(0xffffffff, q_sq_local, offset); + k_sq_local += __shfl_xor_sync(0xffffffff, k_sq_local, offset); + } + + if (lane_id == 0) { + warp_scratch[warp_id * 2] = q_sq_local; + warp_scratch[warp_id * 2 + 1] = k_sq_local; + } + __syncthreads(); + + if (tid == 0) { + float total_q = 0.0f, total_k = 0.0f; + #pragma unroll + for (int w = 0; w < NUM_WARPS; w++) { + total_q += warp_scratch[w * 2]; + total_k += warp_scratch[w * 2 + 1]; + } + shared_q_norm = rsqrtf(total_q + eps); + shared_k_norm = rsqrtf(total_k + eps); + shared_decay = expf(fminf(g_ptr[t], 50.0f)); + shared_beta = 
sigmoid_f(beta_ptr[t]); + } + __syncthreads(); + + float q_norm = shared_q_norm; + float k_norm = shared_k_norm; + float decay = shared_decay; + float beta_val = shared_beta; + + // Normalize and prepare vectors + #pragma unroll 2 + for (int i = tid; i < HEAD_DIM; i += blockDim.x) { + sQ[i] = sQ[i] * q_norm * scale; + sK[i] = sK[i] * k_norm; + sKBeta[i] = sK[i]; + sVBeta[i] = sV[i] * beta_val; + sKCumdecay[i] = sK[i] * beta_val * decay; + } + __syncthreads(); + + // v_prime = state @ k_cumdecay + for (int col = warp_id; col < HEAD_DIM; col += NUM_WARPS) { + float sum = 0.0f; + #pragma unroll 4 + for (int row = lane_id; row < HEAD_DIM; row += WARP_SIZE) { + sum += state_smem[row + col * HEAD_DIM] * sKCumdecay[row]; + } + #pragma unroll + for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { + sum += __shfl_xor_sync(0xffffffff, sum, offset); + } + if (lane_id == 0) sVPrime[col] = sum; + } + __syncthreads(); + + // v_new = v_beta - v_prime + #pragma unroll 2 + for (int i = tid; i < HEAD_DIM; i += blockDim.x) { + sVNew[i] = sVBeta[i] - sVPrime[i]; + } + __syncthreads(); + + // attn_score = dot(K, Q) + if (warp_id == 0) { + float sum = 0.0f; + #pragma unroll 4 + for (int i = lane_id; i < HEAD_DIM; i += WARP_SIZE) { + sum += sK[i] * sQ[i]; + } + #pragma unroll + for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { + sum += __shfl_xor_sync(0xffffffff, sum, offset); + } + if (lane_id == 0) shared_attn_score = sum; + } + __syncthreads(); + + float attn_score = shared_attn_score; + + // output = (state @ q*decay) + v_new * attn_score + for (int col = warp_id; col < HEAD_DIM; col += NUM_WARPS) { + float sum = 0.0f; + #pragma unroll 4 + for (int row = lane_id; row < HEAD_DIM; row += WARP_SIZE) { + sum += state_smem[row + col * HEAD_DIM] * sQ[row] * decay; + } + #pragma unroll + for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { + sum += __shfl_xor_sync(0xffffffff, sum, offset); + } + if (lane_id == 0) sOut[col] = sum + sVNew[col] * attn_score; + } + __syncthreads(); + + // Update state: state_new = decay * state + outer(v_new, k) + float safe_decay = (isnan(decay) || isinf(decay)) ? 
1.0f : decay; + for (int col = tid; col < HEAD_DIM; col += blockDim.x) { + float v_col = sVNew[col]; + for (int row = 0; row < HEAD_DIM; row++) { + float old_state = state_smem[row + col * HEAD_DIM]; + float new_state = safe_decay * old_state + v_col * sKBeta[row]; + state_smem[row + col * HEAD_DIM] = fminf(fmaxf(new_state, -1e6f), 1e6f); + } + } + __syncthreads(); + + // Write output + #pragma unroll 2 + for (int i = tid; i < HEAD_DIM; i += blockDim.x) { + out_base[t * out_token_stride + i] = sOut[i]; + } + __syncthreads(); + } + + // Write final state (transpose back) + #pragma unroll 8 + for (int i = tid; i < HEAD_DIM * HEAD_DIM; i += blockDim.x) { + int col = i / HEAD_DIM, row = i % HEAD_DIM; + state_dst[col + row * HEAD_DIM] = state_smem[row + col * HEAD_DIM]; + } +} + +// Blackwell V2: Bank-conflict-free with padded layout (128→132) +template +__global__ __launch_bounds__(256, 1) +void delta_net_blackwell_optimized_f32( + const float * __restrict__ q, + const float * __restrict__ k, + const float * __restrict__ v, + const float * __restrict__ g, + const float * __restrict__ beta_in, + const float * __restrict__ state_in, + float * __restrict__ dst, + const int64_t n_tokens, + const int64_t n_heads, + const int64_t n_seqs, + const int64_t output_offset, + const float eps) +{ + static_assert(HEAD_DIM == 128, "Optimized kernel for HEAD_DIM=128"); + constexpr int PADDED_DIM = HEAD_DIM + 4; // Bank conflict elimination + + const int batch_idx = blockIdx.x / n_heads; + const int head_idx = blockIdx.x % n_heads; + const int tid = threadIdx.x; + const int warp_id = tid / WARP_SIZE; + const int lane_id = tid % WARP_SIZE; + constexpr int NUM_WARPS = 8; + + const int64_t qkv_stride_token = HEAD_DIM; + const int64_t qkv_stride_head = HEAD_DIM * n_tokens; + const int64_t qkv_stride_batch = HEAD_DIM * n_tokens * n_heads; + const int64_t g_stride_head = n_tokens; + const int64_t g_stride_batch = n_tokens * n_heads; + const int64_t state_head_offset = head_idx * HEAD_DIM * HEAD_DIM; + const int64_t state_batch_stride = HEAD_DIM * HEAD_DIM * n_heads; + + const float * q_ptr = q + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; + const float * k_ptr = k + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; + const float * v_ptr = v + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; + const float * g_ptr = g + batch_idx * g_stride_batch + head_idx * g_stride_head; + const float * beta_ptr = beta_in + batch_idx * g_stride_batch + head_idx * g_stride_head; + const float * state_src = state_in + batch_idx * state_batch_stride + state_head_offset; + float * out_base = dst + batch_idx * (HEAD_DIM * n_heads * n_tokens) + head_idx * HEAD_DIM; + const int64_t out_token_stride = HEAD_DIM * n_heads; + float * state_dst = dst + output_offset + batch_idx * state_batch_stride + state_head_offset; + + // Shared memory: 67.5KB padded state + 4.5KB vectors + extern __shared__ char smem_raw[]; + float * state_smem = (float *)smem_raw; + float * sQ = (float *)(smem_raw + HEAD_DIM * PADDED_DIM * sizeof(float)); + float * sK = sQ + HEAD_DIM; + float * sV = sK + HEAD_DIM; + float * sKBeta = sV + HEAD_DIM; + float * sVBeta = sKBeta + HEAD_DIM; + float * sKCumdecay = sVBeta + HEAD_DIM; + float * sVPrime = sKCumdecay + HEAD_DIM; + float * sVNew = sVPrime + HEAD_DIM; + float * sOut = sVNew + HEAD_DIM; + float * warp_scratch = sOut + HEAD_DIM; + __shared__ float shared_decay, shared_beta, shared_attn_score, shared_q_norm, shared_k_norm; + const float scale = rsqrtf((float)HEAD_DIM); + + // Load state 
with padding + #pragma unroll 8 + for (int i = tid; i < HEAD_DIM * HEAD_DIM; i += blockDim.x) { + int col = i / HEAD_DIM, row = i % HEAD_DIM; + state_smem[row + col * PADDED_DIM] = state_src[row + col * HEAD_DIM]; + } + __syncthreads(); + + for (int64_t t = 0; t < n_tokens; t++) { + // Load Q, K, V (vectorized) + float q_sq_local = 0.0f, k_sq_local = 0.0f; + const float4 * q_ptr_v = (const float4 *)(q_ptr + t * qkv_stride_token); + const float4 * k_ptr_v = (const float4 *)(k_ptr + t * qkv_stride_token); + const float4 * v_ptr_v = (const float4 *)(v_ptr + t * qkv_stride_token); + + #pragma unroll 2 + for (int i = tid; i < HEAD_DIM / 4; i += blockDim.x) { + float4 q_val = q_ptr_v[i]; + float4 k_val = k_ptr_v[i]; + float4 v_val = v_ptr_v[i]; + int base = i * 4; + sQ[base] = q_val.x; sQ[base+1] = q_val.y; sQ[base+2] = q_val.z; sQ[base+3] = q_val.w; + sK[base] = k_val.x; sK[base+1] = k_val.y; sK[base+2] = k_val.z; sK[base+3] = k_val.w; + sV[base] = v_val.x; sV[base+1] = v_val.y; sV[base+2] = v_val.z; sV[base+3] = v_val.w; + q_sq_local += q_val.x*q_val.x + q_val.y*q_val.y + q_val.z*q_val.z + q_val.w*q_val.w; + k_sq_local += k_val.x*k_val.x + k_val.y*k_val.y + k_val.z*k_val.z + k_val.w*k_val.w; + } + + // Warp reduction for norms + #pragma unroll + for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { + q_sq_local += __shfl_xor_sync(0xffffffff, q_sq_local, offset); + k_sq_local += __shfl_xor_sync(0xffffffff, k_sq_local, offset); + } + + // Cross-warp reduction using shared memory + if (lane_id == 0) { + warp_scratch[warp_id * 2] = q_sq_local; + warp_scratch[warp_id * 2 + 1] = k_sq_local; + } + __syncthreads(); + + if (tid == 0) { + float total_q = 0.0f, total_k = 0.0f; + #pragma unroll + for (int w = 0; w < NUM_WARPS; w++) { + total_q += warp_scratch[w * 2]; + total_k += warp_scratch[w * 2 + 1]; + } + shared_q_norm = rsqrtf(total_q + eps); + shared_k_norm = rsqrtf(total_k + eps); + shared_decay = expf(fminf(g_ptr[t], 50.0f)); + shared_beta = sigmoid_f(beta_ptr[t]); + } + __syncthreads(); + + float q_norm = shared_q_norm, k_norm = shared_k_norm; + float decay = shared_decay, beta_val = shared_beta; + + // Normalize vectors + #pragma unroll 2 + for (int i = tid; i < HEAD_DIM; i += blockDim.x) { + sQ[i] = sQ[i] * q_norm * scale; + sK[i] = sK[i] * k_norm; + sKBeta[i] = sK[i]; + sVBeta[i] = sV[i] * beta_val; + sKCumdecay[i] = sK[i] * beta_val * decay; + } + __syncthreads(); + + // v_prime = state @ k_cumdecay + for (int row_out = warp_id; row_out < HEAD_DIM; row_out += NUM_WARPS) { + float sum = 0.0f; + #pragma unroll 4 + for (int col = lane_id; col < HEAD_DIM; col += WARP_SIZE) { + sum += state_smem[row_out + col * PADDED_DIM] * sKCumdecay[col]; + } + #pragma unroll + for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { + sum += __shfl_xor_sync(0xffffffff, sum, offset); + } + if (lane_id == 0) sVPrime[row_out] = sum; + } + __syncthreads(); + + // v_new = v_beta - v_prime + #pragma unroll 2 + for (int i = tid; i < HEAD_DIM; i += blockDim.x) { + sVNew[i] = sVBeta[i] - sVPrime[i]; + } + __syncthreads(); + + // attn_score = dot(K, Q) + if (warp_id == 0) { + float sum = 0.0f; + #pragma unroll 4 + for (int i = lane_id; i < HEAD_DIM; i += WARP_SIZE) { + sum += sK[i] * sQ[i]; + } + #pragma unroll + for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { + sum += __shfl_xor_sync(0xffffffff, sum, offset); + } + if (lane_id == 0) shared_attn_score = sum; + } + __syncthreads(); + + float attn_score = shared_attn_score; + + // output = (state @ q*decay) + v_new * attn_score + for (int row_out = 
warp_id; row_out < HEAD_DIM; row_out += NUM_WARPS) { + float sum = 0.0f; + #pragma unroll 4 + for (int col = lane_id; col < HEAD_DIM; col += WARP_SIZE) { + sum += state_smem[row_out + col * PADDED_DIM] * sQ[col] * decay; + } + #pragma unroll + for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { + sum += __shfl_xor_sync(0xffffffff, sum, offset); + } + if (lane_id == 0) sOut[row_out] = sum + sVNew[row_out] * attn_score; + } + __syncthreads(); + + // State update + #pragma unroll 4 + for (int i = tid; i < HEAD_DIM * HEAD_DIM; i += blockDim.x) { + int col = i / HEAD_DIM, row = i % HEAD_DIM; + float old_state = state_smem[row + col * PADDED_DIM]; + float new_state = decay * old_state + sKBeta[row] * sVNew[col]; + state_smem[row + col * PADDED_DIM] = fminf(fmaxf(new_state, -1e6f), 1e6f); + } + __syncthreads(); + + // Write output (vectorized) + float4 * out_ptr_v = (float4 *)(out_base + t * out_token_stride); + #pragma unroll 2 + for (int i = tid; i < HEAD_DIM / 4; i += blockDim.x) { + int base = i * 4; + float4 out_val = {sOut[base], sOut[base+1], sOut[base+2], sOut[base+3]}; + out_ptr_v[i] = out_val; + } + __syncthreads(); + } + + // Write final state (remove padding) + #pragma unroll 8 + for (int i = tid; i < HEAD_DIM * HEAD_DIM; i += blockDim.x) { + int col = i / HEAD_DIM, row = i % HEAD_DIM; + state_dst[row + col * HEAD_DIM] = state_smem[row + col * PADDED_DIM]; + } +} + +#endif // !defined(GGML_USE_HIP) + +// Multi-block column-parallel kernel (pre-Blackwell fallback) +// Each block handles COLS_PER_BLOCK columns of the 128x128 state +// With COLS_PER_BLOCK=16: 128/16 = 8 blocks per head, 16 heads = 128 blocks total +// State tile per block: 128 rows × 16 cols = 2048 floats = 8KB (fits in shared memory!) +template +__global__ void delta_net_multiblock_f32( + const float * __restrict__ q, // [HEAD_DIM, n_tokens, n_heads, n_seqs] + const float * __restrict__ k, // [HEAD_DIM, n_tokens, n_heads, n_seqs] + const float * __restrict__ v, // [HEAD_DIM, n_tokens, n_heads, n_seqs] + const float * __restrict__ g, // [n_tokens, 1, n_heads, n_seqs] + const float * __restrict__ beta_in, // [1, n_tokens, n_heads, n_seqs] + const float * __restrict__ state_in, // [HEAD_DIM, HEAD_DIM*n_heads, 1, n_seqs] + float * __restrict__ dst, // output + new_state concatenated + const int64_t n_tokens, + const int64_t n_heads, + const int64_t n_seqs, + const int64_t output_offset, + const float eps) +{ + static_assert(HEAD_DIM % COLS_PER_BLOCK == 0, "HEAD_DIM must be divisible by COLS_PER_BLOCK"); + constexpr int NUM_COL_GROUPS = HEAD_DIM / COLS_PER_BLOCK; + + // Decode block index: (batch_idx, head_idx, col_group) + const int blocks_per_seq = n_heads * NUM_COL_GROUPS; + const int batch_idx = blockIdx.x / blocks_per_seq; + const int remaining = blockIdx.x % blocks_per_seq; + const int head_idx = remaining / NUM_COL_GROUPS; + const int col_group = remaining % NUM_COL_GROUPS; + const int col_start = col_group * COLS_PER_BLOCK; + + const int tid = threadIdx.x; + const int warp_id = tid / WARP_SIZE; + const int lane_id = tid % WARP_SIZE; + constexpr int NUM_WARPS = 8; + + // Strides (column-major) + const int64_t qkv_stride_token = HEAD_DIM; + const int64_t qkv_stride_head = HEAD_DIM * n_tokens; + const int64_t qkv_stride_batch = HEAD_DIM * n_tokens * n_heads; + const int64_t g_stride_head = n_tokens; + const int64_t g_stride_batch = n_tokens * n_heads; + const int64_t state_head_offset = head_idx * HEAD_DIM * HEAD_DIM; + const int64_t state_batch_stride = HEAD_DIM * HEAD_DIM * n_heads; + + // Pointers + const float 
* q_ptr = q + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; + const float * k_ptr = k + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; + const float * v_ptr = v + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; + const float * g_ptr = g + batch_idx * g_stride_batch + head_idx * g_stride_head; + const float * beta_ptr = beta_in + batch_idx * g_stride_batch + head_idx * g_stride_head; + const float * state_src = state_in + batch_idx * state_batch_stride + state_head_offset; + + float * out_base = dst + batch_idx * (HEAD_DIM * n_heads * n_tokens) + head_idx * HEAD_DIM; + const int64_t out_token_stride = HEAD_DIM * n_heads; + float * state_dst = dst + output_offset + batch_idx * state_batch_stride + state_head_offset; + + // Shared memory layout: + // - State tile: HEAD_DIM × COLS_PER_BLOCK = 128 × 16 = 2048 floats = 8KB + // - Full vectors: K, KCumdecay, Q (need all HEAD_DIM elements) = 3 × 128 = 1.5KB + // - Local vectors: V, VBeta, VPrime, VNew, Out (only COLS_PER_BLOCK) = 5 × 16 = 320 bytes + // - Norms: 2 floats + // Total: ~10KB (excellent for occupancy!) + extern __shared__ float smem[]; + + // State tile in shared memory: state_tile[row + local_col * HEAD_DIM] + // local_col ∈ [0, COLS_PER_BLOCK), global_col = col_start + local_col + float * state_tile = smem; // HEAD_DIM * COLS_PER_BLOCK + + // Full vectors (need all HEAD_DIM for matrix-vector and dot products) + float * sK = state_tile + HEAD_DIM * COLS_PER_BLOCK; // HEAD_DIM + float * sKCumdecay = sK + HEAD_DIM; // HEAD_DIM + float * sQ = sKCumdecay + HEAD_DIM; // HEAD_DIM + + // Local vectors (only need COLS_PER_BLOCK elements) + float * sV = sQ + HEAD_DIM; // COLS_PER_BLOCK + float * sVBeta = sV + COLS_PER_BLOCK; // COLS_PER_BLOCK + float * sVPrime = sVBeta + COLS_PER_BLOCK; // COLS_PER_BLOCK + float * sVNew = sVPrime + COLS_PER_BLOCK; // COLS_PER_BLOCK + float * sOut = sVNew + COLS_PER_BLOCK; // COLS_PER_BLOCK + float * sNorm = sOut + COLS_PER_BLOCK; // 2 + + __shared__ float shared_decay, shared_beta, shared_attn_score; + + const float scale = rsqrtf((float)HEAD_DIM); + + // Load initial state tile from global to shared + // state_tile[row + local_col * HEAD_DIM] = state[row, col_start + local_col] + for (int i = tid; i < HEAD_DIM * COLS_PER_BLOCK; i += blockDim.x) { + int row = i % HEAD_DIM; + int local_col = i / HEAD_DIM; + int global_col = col_start + local_col; + state_tile[row + local_col * HEAD_DIM] = state_src[row + global_col * HEAD_DIM]; + } + __syncthreads(); + + // Process each token + for (int64_t t = 0; t < n_tokens; t++) { + // Reset norms + if (tid < 2) { + sNorm[tid] = 0.0f; + } + __syncthreads(); + + // 1. 
Load full K, Q (all HEAD_DIM elements - needed for matrix-vector and attn_score) + float q_sq_local = 0.0f, k_sq_local = 0.0f; + for (int i = tid; i < HEAD_DIM; i += blockDim.x) { + float q_val = q_ptr[t * qkv_stride_token + i]; + float k_val = k_ptr[t * qkv_stride_token + i]; + sQ[i] = q_val; + sK[i] = k_val; + q_sq_local += q_val * q_val; + k_sq_local += k_val * k_val; + } + + // Load V for our columns only + for (int i = tid; i < COLS_PER_BLOCK; i += blockDim.x) { + sV[i] = v_ptr[t * qkv_stride_token + col_start + i]; + } + + // Warp reduction for norms + #pragma unroll + for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { + q_sq_local += __shfl_xor_sync(0xffffffff, q_sq_local, offset); + k_sq_local += __shfl_xor_sync(0xffffffff, k_sq_local, offset); + } + if (lane_id == 0) { + atomicAdd(&sNorm[0], q_sq_local); + atomicAdd(&sNorm[1], k_sq_local); + } + __syncthreads(); + + float q_norm = rsqrtf(sNorm[0] + eps); + float k_norm = rsqrtf(sNorm[1] + eps); + + // 2. Load g, beta and normalize vectors + if (tid == 0) { + shared_decay = expf(fminf(g_ptr[t], 50.0f)); // Clamp g to prevent overflow + shared_beta = sigmoid_f(beta_ptr[t]); + } + __syncthreads(); + + float decay = shared_decay; + float beta_val = shared_beta; + + // Normalize and compute KCumdecay + for (int i = tid; i < HEAD_DIM; i += blockDim.x) { + sQ[i] = sQ[i] * q_norm * scale; + sK[i] = sK[i] * k_norm; + sKCumdecay[i] = sK[i] * beta_val * decay; + } + + // Compute VBeta for our columns + for (int i = tid; i < COLS_PER_BLOCK; i += blockDim.x) { + sVBeta[i] = sV[i] * beta_val; + } + __syncthreads(); + + // 3. Compute v_prime for our columns: v_prime[local_col] = sum_row(state_tile[row, local_col] * k_cumdecay[row]) + // Each warp handles one local column + for (int local_col = warp_id; local_col < COLS_PER_BLOCK; local_col += NUM_WARPS) { + float sum = 0.0f; + #pragma unroll 4 + for (int row = lane_id; row < HEAD_DIM; row += WARP_SIZE) { + sum += state_tile[row + local_col * HEAD_DIM] * sKCumdecay[row]; + } + // Warp reduction + #pragma unroll + for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { + sum += __shfl_xor_sync(0xffffffff, sum, offset); + } + if (lane_id == 0) { + sVPrime[local_col] = sum; + } + } + __syncthreads(); + + // 4. Compute v_new for our columns + for (int i = tid; i < COLS_PER_BLOCK; i += blockDim.x) { + sVNew[i] = sVBeta[i] - sVPrime[i]; + } + __syncthreads(); + + // 5. Compute attn_score = dot(k, q) - all blocks compute this redundantly + if (warp_id == 0) { + float sum = 0.0f; + #pragma unroll 4 + for (int i = lane_id; i < HEAD_DIM; i += WARP_SIZE) { + sum += sK[i] * sQ[i]; + } + #pragma unroll + for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { + sum += __shfl_xor_sync(0xffffffff, sum, offset); + } + if (lane_id == 0) { + shared_attn_score = sum; + } + } + __syncthreads(); + + // 6. Compute output for our columns: out[local_col] = attn_inter + v_attn + // attn_inter[local_col] = sum_row(state_tile[row, local_col] * q_scaled[row]) + for (int local_col = warp_id; local_col < COLS_PER_BLOCK; local_col += NUM_WARPS) { + float sum = 0.0f; + #pragma unroll 4 + for (int row = lane_id; row < HEAD_DIM; row += WARP_SIZE) { + sum += state_tile[row + local_col * HEAD_DIM] * sQ[row] * decay; + } + #pragma unroll + for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { + sum += __shfl_xor_sync(0xffffffff, sum, offset); + } + if (lane_id == 0) { + float v_attn = sVNew[local_col] * shared_attn_score; + sOut[local_col] = sum + v_attn; + } + } + __syncthreads(); + + // 7. 
Update state tile: state_new[row, local_col] = decay * state[row, local_col] + v_new[row] * k[local_col] + // Fixed: outer product orientation matches decomposed: state[v_idx, k_idx] += v_new[v_idx] * k[k_idx] + float safe_decay = fminf(fmaxf(decay, 0.0f), 10.0f); + for (int i = tid; i < HEAD_DIM * COLS_PER_BLOCK; i += blockDim.x) { + int row = i % HEAD_DIM; + int local_col = i / HEAD_DIM; + + float state_val = state_tile[row + local_col * HEAD_DIM]; + // Fix: v_new[row=v_idx] * k[local_col=k_idx] to match decomposed + float new_val = safe_decay * state_val + sVNew[row] * sK[local_col]; + new_val = fminf(fmaxf(new_val, -1e6f), 1e6f); + state_tile[row + local_col * HEAD_DIM] = new_val; + } + __syncthreads(); + + // 8. Write output for our columns + for (int i = tid; i < COLS_PER_BLOCK; i += blockDim.x) { + int global_col = col_start + i; + out_base[t * out_token_stride + global_col] = sOut[i]; + } + __syncthreads(); + } + + // Write final state tile back to global + for (int i = tid; i < HEAD_DIM * COLS_PER_BLOCK; i += blockDim.x) { + int row = i % HEAD_DIM; + int local_col = i / HEAD_DIM; + int global_col = col_start + local_col; + state_dst[row + global_col * HEAD_DIM] = state_tile[row + local_col * HEAD_DIM]; + } +} + +// Dispatch function +// device_id and cc (compute capability) are passed from caller to avoid CUDA runtime API calls +static void delta_net_f32_cuda( + const float * q, + const float * k, + const float * v, + const float * g, + const float * beta, + const float * state_in, + float * dst, + const int64_t head_dim, + const int64_t n_tokens, + const int64_t n_heads, + const int64_t n_seqs, + const float eps, + const int device_id, + const int cc, // compute capability (e.g., 890 for SM 8.9, 1200 for SM 12.0) + cudaStream_t stream) +{ + const int64_t output_offset = head_dim * n_tokens * n_heads * n_seqs; + + // One block per (batch, head) pair + const int num_blocks = n_seqs * n_heads; + const int threads_per_block = 256; + + // Shared memory: 9 * head_dim (for Q, K, V, KBeta, VBeta, Out, KCumdecay, VPrime, VNew) + // Plus 6 floats for Norm[2], g_val, beta_val, decay, attn_score + const size_t smem_size = (9 * head_dim + 6) * sizeof(float); + + // Use templated kernel for common head dimensions, generic for others + if (head_dim == 64) { + delta_net_recurrent_f32<64><<>>( + q, k, v, g, beta, state_in, dst, n_tokens, n_heads, n_seqs, output_offset, eps); + } else if (head_dim == 128) { +#if !defined(GGML_USE_HIP) + // Check for Blackwell (SM 12.0+) which has 228KB shared memory + // cc is in format MAJOR*100 + MINOR*10 (e.g., 890 for 8.9, 1200 for 12.0) + const int sm_major = cc / 100; + + if (sm_major >= 12) { + // Blackwell path: single block per head with FULL state in shared memory + const int blackwell_num_blocks = n_seqs * n_heads; + const int blackwell_threads = 256; + + // Shared memory calculation with explicit breakdown: + // - State matrix: HEAD_DIM × HEAD_DIM × sizeof(float) = 128×128×4 = 65536 bytes (64KB) + // - Vectors (Q,K,V,KBeta,VBeta,KCumdecay,VPrime,VNew,Out): 9 × HEAD_DIM × sizeof(float) = 4608 bytes + // - Warp scratch: 16 × sizeof(float) = 64 bytes + // Total: 65536 + 4608 + 64 = 70208 bytes (~68.6KB) + // Note: __shared__ scalars (decay, beta, etc.) 
are static, not dynamic + constexpr size_t state_bytes = 128 * 128 * sizeof(float); // 64KB + constexpr size_t vector_bytes = 9 * 128 * sizeof(float); // 4.5KB + constexpr size_t warp_scratch_bytes = 16 * sizeof(float); // 64B + constexpr size_t blackwell_smem_size = state_bytes + vector_bytes + warp_scratch_bytes; + + // Sanity check: ensure we allocated enough + static_assert(blackwell_smem_size == 70208, "Shared memory size mismatch"); + + // Check for A/B comparison mode + // Use a function-local static for thread-safe lazy initialization + static const bool ab_mode = []() { + const char* env = std::getenv("GGML_CUDA_DELTA_NET_AB"); + if (env != nullptr) { + fprintf(stderr, "[DELTA_NET] A/B comparison mode ENABLED\n"); + return true; + } + return false; + }(); + + if (ab_mode) { + // A/B mode: run both kernels and compare outputs + const int64_t total_output_size = output_offset + head_dim * head_dim * n_heads * n_seqs; + + // Allocate temp buffer for recurrent kernel output + float * temp_dst = nullptr; + CUDA_CHECK(cudaMallocAsync(&temp_dst, total_output_size * sizeof(float), stream)); + + // Run recurrent kernel (reference) to temp buffer + delta_net_recurrent_f32<128><<>>( + q, k, v, g, beta, state_in, temp_dst, n_tokens, n_heads, n_seqs, output_offset, eps); + + // Request extended shared memory for Blackwell + CUDA_CHECK(cudaFuncSetAttribute( + delta_net_blackwell_f32<128>, + cudaFuncAttributeMaxDynamicSharedMemorySize, + blackwell_smem_size)); + + // Run Blackwell kernel to dst + delta_net_blackwell_f32<128><<>>( + q, k, v, g, beta, state_in, dst, n_tokens, n_heads, n_seqs, output_offset, eps); + + // Sync to ensure both kernels complete + CUDA_CHECK(cudaStreamSynchronize(stream)); + + // Copy results back to host for comparison + const int64_t output_elements = head_dim * n_tokens * n_heads * n_seqs; + const int64_t state_elements = head_dim * head_dim * n_heads * n_seqs; + + std::vector ref_output(output_elements); + std::vector ref_state(state_elements); + std::vector bw_output(output_elements); + std::vector bw_state(state_elements); + + CUDA_CHECK(cudaMemcpy(ref_output.data(), temp_dst, output_elements * sizeof(float), cudaMemcpyDeviceToHost)); + CUDA_CHECK(cudaMemcpy(ref_state.data(), temp_dst + output_offset, state_elements * sizeof(float), cudaMemcpyDeviceToHost)); + CUDA_CHECK(cudaMemcpy(bw_output.data(), dst, output_elements * sizeof(float), cudaMemcpyDeviceToHost)); + CUDA_CHECK(cudaMemcpy(bw_state.data(), dst + output_offset, state_elements * sizeof(float), cudaMemcpyDeviceToHost)); + + // Compare outputs + float max_out_diff = 0.0f; + int64_t max_out_idx = 0; + for (int64_t i = 0; i < output_elements; i++) { + float diff = fabsf(ref_output[i] - bw_output[i]); + if (diff > max_out_diff) { + max_out_diff = diff; + max_out_idx = i; + } + } + + // Compare states + float max_state_diff = 0.0f; + int64_t max_state_idx = 0; + for (int64_t i = 0; i < state_elements; i++) { + float diff = fabsf(ref_state[i] - bw_state[i]); + if (diff > max_state_diff) { + max_state_diff = diff; + max_state_idx = i; + } + } + + // Report results + static int ab_call_count = 0; + ab_call_count++; + fprintf(stderr, "[DELTA_NET A/B #%d] n_tokens=%lld output_diff=%e (idx=%lld ref=%e bw=%e) state_diff=%e (idx=%lld ref=%e bw=%e)\n", + ab_call_count, + (long long)n_tokens, + max_out_diff, (long long)max_out_idx, ref_output[max_out_idx], bw_output[max_out_idx], + max_state_diff, (long long)max_state_idx, ref_state[max_state_idx], bw_state[max_state_idx]); + + // Report first 4 output values for head 
0 + if (ab_call_count <= 10) { + fprintf(stderr, " ref_out[0:3]=[%e,%e,%e,%e] bw_out[0:3]=[%e,%e,%e,%e]\n", + ref_output[0], ref_output[1], ref_output[2], ref_output[3], + bw_output[0], bw_output[1], bw_output[2], bw_output[3]); + fprintf(stderr, " ref_state[0,1,128,129]=[%e,%e,%e,%e] bw_state=[%e,%e,%e,%e]\n", + ref_state[0], ref_state[1], ref_state[128], ref_state[129], + bw_state[0], bw_state[1], bw_state[128], bw_state[129]); + } + + CUDA_CHECK(cudaFreeAsync(temp_dst, stream)); + } else { + // Normal mode: just run Blackwell kernel + // Request extended shared memory for Blackwell + CUDA_CHECK(cudaFuncSetAttribute( + delta_net_blackwell_f32<128>, + cudaFuncAttributeMaxDynamicSharedMemorySize, + blackwell_smem_size)); + + delta_net_blackwell_f32<128><<>>( + q, k, v, g, beta, state_in, dst, n_tokens, n_heads, n_seqs, output_offset, eps); + } + } else +#endif // !defined(GGML_USE_HIP) + { + // Pre-Blackwell path: Use recurrent kernel + delta_net_recurrent_f32<128><<>>( + q, k, v, g, beta, state_in, dst, n_tokens, n_heads, n_seqs, output_offset, eps); + } + } else { + delta_net_recurrent_generic_f32<<>>( + q, k, v, g, beta, state_in, dst, head_dim, n_tokens, n_heads, n_seqs, output_offset, eps); + } + + // Check for errors (but don't sync during graph capture) + CUDA_CHECK(cudaGetLastError()); + +#ifdef GGML_CUDA_DEBUG_SYNC + // Only sync when not capturing CUDA graphs + cudaStreamCaptureStatus capture_status; + CUDA_CHECK(cudaStreamIsCapturing(stream, &capture_status)); + if (capture_status == cudaStreamCaptureStatusNone) { + CUDA_CHECK(cudaDeviceSynchronize()); + } +#endif +} + +void ggml_cuda_op_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; // q + const ggml_tensor * src1 = dst->src[1]; // k + const ggml_tensor * src2 = dst->src[2]; // v + const ggml_tensor * src3 = dst->src[3]; // g + const ggml_tensor * src4 = dst->src[4]; // beta + const ggml_tensor * src5 = dst->src[5]; // state + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const int64_t head_dim = src0->ne[0]; + const int64_t n_tokens = src0->ne[1]; + const int64_t n_heads = src0->ne[2]; + const int64_t n_seqs = src0->ne[3]; + + // Dimension validation + // Q/K: [head_dim, n_tokens, n_heads, n_seqs] + GGML_ASSERT(src1->ne[0] == head_dim && src1->ne[1] == n_tokens && src1->ne[2] == n_heads && src1->ne[3] == n_seqs); + // V: [head_dim, n_tokens, n_heads, n_seqs] + GGML_ASSERT(src2->ne[0] == head_dim && src2->ne[1] == n_tokens && src2->ne[2] == n_heads && src2->ne[3] == n_seqs); + // G: [n_tokens, 1, n_heads, n_seqs] + GGML_ASSERT(src3->ne[0] == n_tokens && src3->ne[1] == 1 && src3->ne[2] == n_heads && src3->ne[3] == n_seqs); + // Beta: [1, n_tokens, n_heads, n_seqs] + GGML_ASSERT(src4->ne[0] == 1 && src4->ne[1] == n_tokens && src4->ne[2] == n_heads && src4->ne[3] == n_seqs); + // State: [head_dim, head_dim*n_heads, 1, n_seqs] + GGML_ASSERT(src5->ne[0] == head_dim && src5->ne[1] == head_dim * n_heads && src5->ne[2] == 1 && src5->ne[3] == n_seqs); + + // Verify output tensor size + const int64_t output_size = head_dim * n_tokens * n_heads * n_seqs; + const int64_t state_size = head_dim * head_dim * n_heads * n_seqs; + GGML_ASSERT(ggml_nelements(dst) == output_size + state_size); + + const float eps = 1e-6f; + + GGML_ASSERT(head_dim <= 256); // Reasonable limit for shared memory + + // Get device info from ctx (avoids calling CUDA runtime APIs inside dispatch) + const int device_id = ctx.device; + const int cc = 
ggml_cuda_info().devices[device_id].cc; + + delta_net_f32_cuda( + (const float *)src0->data, + (const float *)src1->data, + (const float *)src2->data, + (const float *)src3->data, + (const float *)src4->data, + (const float *)src5->data, + (float *)dst->data, + head_dim, n_tokens, n_heads, n_seqs, eps, + device_id, cc, + ctx.stream()); + +} diff --git a/ggml/src/ggml-cuda/delta-net.cuh b/ggml/src/ggml-cuda/delta-net.cuh new file mode 100644 index 0000000000..a9b223c664 --- /dev/null +++ b/ggml/src/ggml-cuda/delta-net.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index ab0f6fe9ce..c2d2e2b927 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -55,6 +55,7 @@ #include "ggml-cuda/set-rows.cuh" #include "ggml-cuda/pad_reflect_1d.cuh" #include "ggml-cuda/solve_tri.cuh" +#include "ggml-cuda/delta-net.cuh" #include "ggml-cuda/tri.cuh" #include "ggml-cuda/cumsum.cuh" #include "ggml-cuda/fill.cuh" @@ -2735,6 +2736,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_SOLVE_TRI: ggml_cuda_op_solve_tri(ctx, dst); break; + case GGML_OP_DELTA_NET: + ggml_cuda_op_delta_net(ctx, dst); + break; case GGML_OP_FILL: ggml_cuda_op_fill(ctx, dst); break; @@ -2904,6 +2908,13 @@ static bool check_node_graph_compatibility(ggml_cgraph * cgraph, #endif } + if (node->op == GGML_OP_DELTA_NET) { + use_cuda_graph = false; +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: disabling CUDA graphs due to DELTA_NET recurrent state\n", __func__); +#endif + } + if (!use_cuda_graph) { break; } @@ -4632,6 +4643,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_DIAG: case GGML_OP_SOLVE_TRI: return true; + case GGML_OP_DELTA_NET: + return op->src[0]->ne[0] <= 256 && op->src[2]->ne[0] <= 256; default: return false; diff --git a/ggml/src/ggml-cuda/solve_tri.cu b/ggml/src/ggml-cuda/solve_tri.cu index 177ffc268f..0e92ba5012 100644 --- a/ggml/src/ggml-cuda/solve_tri.cu +++ b/ggml/src/ggml-cuda/solve_tri.cu @@ -1,86 +1,533 @@ #include "common.cuh" #include "ggml.h" #include "solve_tri.cuh" +#include "ggml-cuda.h" +#include #define MAX_N_FAST 64 -#define MAX_K_FAST 32 +#define MAX_K_FAST 64 -static __global__ void get_batch_pointers(const float * A, - float * X, - const float ** A_ptrs, - float ** X_ptrs, - int64_t ne02, - int64_t total_batches, - size_t s02, - size_t s03, - size_t s2, - size_t s3) { - const int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= total_batches) { - return; - } +// Kernel to set up pointer arrays for batched cuBLAS TRSM +// This avoids host-device copy during CUDA graph capture +static __global__ void setup_trsm_batch_pointers( + const float * A, + float * X, + const float ** A_ptrs, + float ** X_ptrs, + const int64_t ne02, + const int64_t total_batches, + const size_t nb02, // stride for A dim 2 (in floats) + const size_t nb03, // stride for A dim 3 (in floats) + const size_t nb2, // stride for X dim 2 (in floats) + const size_t nb3 // stride for X dim 3 (in floats) +) { + const int64_t batch_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (batch_idx >= total_batches) return; - const int64_t i3 = idx / ne02; - const int64_t i2 = idx % ne02; + // Decompose batch_idx into i02, i03 + const int64_t i02 = batch_idx % ne02; + const int64_t i03 = batch_idx / ne02; - A_ptrs[idx] = A + i3 * s03 + i2 * s02; - X_ptrs[idx] = X + i3 * s3 + 
i2 * s2; + A_ptrs[batch_idx] = A + i02 * nb02 + i03 * nb03; + X_ptrs[batch_idx] = X + i02 * nb2 + i03 * nb3; } -static void solve_tri_f32_cublas(ggml_backend_cuda_context & ctx, - const float * A, - const float * B, - float * X, - int n, - int k, - int64_t ne02, - int64_t ne03, - size_t s02, - size_t s03, - size_t s12, - size_t s13, - size_t s2, - size_t s3, - cudaStream_t stream) { - const float alpha = 1.0f; - const int64_t total_batches = ne02 * ne03; - if (total_batches == 0) { - return; +// Latency-optimized kernel for n=64, k=64 (single-token generation) +static __global__ void solve_tri_f32_64x64_latency( + const float * __restrict__ A, + const float * __restrict__ B, + float * __restrict__ X, + const uint3 ne02, + const size_t nb02, + const size_t nb03, + const size_t nb12, + const size_t nb13, + const size_t nb2, + const size_t nb3) +{ + const int batch_idx = blockIdx.x; + const int lane = threadIdx.x; + const int warp_id = threadIdx.y; + + const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02); + const int64_t i02 = i02_i03.y; + const int64_t i03 = i02_i03.x; + + const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03); + const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13); + float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3); + + // Shared memory: A is 64x64, X is 64x65 (padded for bank conflicts) + __shared__ float sA[64 * 64]; + __shared__ float sX[64 * 65]; + __shared__ float sDiagInv[64]; // Precomputed 1/diagonal + + const int tid = lane + warp_id * WARP_SIZE; + + // Cooperative load of A matrix (4096 elements / 512 threads = 8 per thread) + #pragma unroll 8 + for (int i = tid; i < 64 * 64; i += 512) { + sA[i] = A_batch[i]; } - // Bulk copy B -> X (contiguous tensors) - if (X != B) { - const int64_t total_elements_BX = n * k * total_batches; - CUDA_CHECK(cudaMemcpyAsync(X, B, total_elements_BX * sizeof(float), cudaMemcpyDeviceToDevice, stream)); + // Cooperative load of B matrix into sX with padding + #pragma unroll 8 + for (int i = tid; i < 64 * 64; i += 512) { + const int row = i / 64; + const int col = i % 64; + sX[row * 65 + col] = B_batch[i]; } - const int id = ggml_cuda_get_device(); + __syncthreads(); - ggml_cuda_pool_alloc A_ptrs_alloc(ctx.pool(id), total_batches); - ggml_cuda_pool_alloc X_ptrs_alloc(ctx.pool(id), total_batches); + // Precompute diagonal inverses (first 2 warps handle this) + if (warp_id == 0) { + if (lane < 32) { + sDiagInv[lane] = 1.0f / sA[lane * 64 + lane]; + } + } + if (warp_id == 1) { + if (lane < 32) { + sDiagInv[32 + lane] = 1.0f / sA[(32 + lane) * 64 + (32 + lane)]; + } + } - const float ** A_ptrs_dev = A_ptrs_alloc.get(); - float ** X_ptrs_dev = X_ptrs_alloc.get(); + __syncthreads(); - get_batch_pointers<<<(total_batches + 255) / 256, 256, 0, stream>>>(A, X, A_ptrs_dev, X_ptrs_dev, ne02, - total_batches, s02, s03, s2, s3); + // Each warp handles 4 columns: cols = warp_id*4 to warp_id*4+3 + const int col_base = warp_id * 4; - CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream)); + #pragma unroll 1 + for (int row = 0; row < 64; ++row) { + float sum0 = 0.0f, sum1 = 0.0f, sum2 = 0.0f, sum3 = 0.0f; - // Yes, this is necessary, without this we get RMSE errors - CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_DEFAULT_MATH)); - CUBLAS_CHECK(cublasStrsmBatched(ctx.cublas_handle(id), CUBLAS_SIDE_RIGHT, CUBLAS_FILL_MODE_UPPER, CUBLAS_OP_N, - CUBLAS_DIAG_NON_UNIT, k, n, &alpha, A_ptrs_dev, n, X_ptrs_dev, k, total_batches)); + if (row > 0) { + for (int j = lane; j < row; j += 
WARP_SIZE) { + const float a_val = sA[row * 64 + j]; + sum0 += a_val * sX[j * 65 + col_base + 0]; + sum1 += a_val * sX[j * 65 + col_base + 1]; + sum2 += a_val * sX[j * 65 + col_base + 2]; + sum3 += a_val * sX[j * 65 + col_base + 3]; + } + } - // revert to standard mode from common.cuh - CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_TF32_TENSOR_OP_MATH)); + sum0 = warp_reduce_sum(sum0); + sum1 = warp_reduce_sum(sum1); + sum2 = warp_reduce_sum(sum2); + sum3 = warp_reduce_sum(sum3); - GGML_UNUSED_VARS(s12, s13); + if (lane == 0) { + const float inv_diag = sDiagInv[row]; + sX[row * 65 + col_base + 0] = (sX[row * 65 + col_base + 0] - sum0) * inv_diag; + sX[row * 65 + col_base + 1] = (sX[row * 65 + col_base + 1] - sum1) * inv_diag; + sX[row * 65 + col_base + 2] = (sX[row * 65 + col_base + 2] - sum2) * inv_diag; + sX[row * 65 + col_base + 3] = (sX[row * 65 + col_base + 3] - sum3) * inv_diag; + } + + __syncthreads(); + } + + // Cooperative write results back + #pragma unroll 8 + for (int i = tid; i < 64 * 64; i += 512) { + const int row = i / 64; + const int col = i % 64; + X_batch[i] = sX[row * 65 + col]; + } +} + +static __global__ void solve_tri_f32_64x64_opt(const float * __restrict__ A, + const float * __restrict__ B, + float * __restrict__ X, + const uint3 ne02, + const size_t nb02, + const size_t nb03, + const size_t nb12, + const size_t nb13, + const size_t nb2, + const size_t nb3) { + const int batch_idx = blockIdx.x; + const int lane = threadIdx.x; + const int warp_id = threadIdx.y; + + const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02); + const int64_t i02 = i02_i03.y; + const int64_t i03 = i02_i03.x; + + const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03); + const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13); + float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3); + + // Shared memory: A is 64x64, sXt is 64x65 (padded) + __shared__ float sA[64 * 64]; + __shared__ float sXt[64 * 65]; + + const int tid = lane + warp_id * WARP_SIZE; + + // Cooperative load of A matrix (4096 elements / 1024 threads = 4 per thread) + #pragma unroll 4 + for (int i = tid; i < 64 * 64; i += 1024) { + sA[i] = A_batch[i]; + } + + // Cooperative load of B matrix transposed into sXt + // sXt[col * 65 + row] = B[row * 64 + col] + #pragma unroll 4 + for (int i = tid; i < 64 * 64; i += 1024) { + const int row = i / 64; + const int col = i % 64; + sXt[col * 65 + row] = B_batch[row * 64 + col]; + } + + __syncthreads(); + + // Each warp handles 2 columns: col0 = warp_id*2, col1 = warp_id*2 + 1 + const int col0 = warp_id * 2; + const int col1 = warp_id * 2 + 1; + + // Forward substitution with all columns processed in parallel + // Each row depends on previous rows, but different columns are independent + #pragma unroll 1 + for (int row = 0; row < 64; ++row) { + // Each lane computes partial sum for indices it handles + float sum0 = 0.0f; + float sum1 = 0.0f; + + // Sum over j < row + // For row <= 32: each lane handles at most 1 element + // For row > 32: each lane handles at most 2 elements + if (lane < row) { + const float a_val = sA[row * 64 + lane]; + sum0 = a_val * sXt[col0 * 65 + lane]; + sum1 = a_val * sXt[col1 * 65 + lane]; + } + if (row > WARP_SIZE) { + const int j2 = lane + WARP_SIZE; + if (j2 < row) { + const float a_val2 = sA[row * 64 + j2]; + sum0 += a_val2 * sXt[col0 * 65 + j2]; + sum1 += a_val2 * sXt[col1 * 65 + j2]; + } + } + + // Warp-level reduction + sum0 = warp_reduce_sum(sum0); + sum1 = warp_reduce_sum(sum1); + + // Lane 0 
computes and stores the result + if (lane == 0) { + const float a_diag = sA[row * 64 + row]; + const float inv_diag = 1.0f / a_diag; + sXt[col0 * 65 + row] = (sXt[col0 * 65 + row] - sum0) * inv_diag; + sXt[col1 * 65 + row] = (sXt[col1 * 65 + row] - sum1) * inv_diag; + } + + // Sync within warp to ensure writes are visible before next row reads + __syncwarp(); + } + + __syncthreads(); + + // Cooperative write of results back (transpose sXt to X) + #pragma unroll 4 + for (int i = tid; i < 64 * 64; i += 1024) { + const int row = i / 64; + const int col = i % 64; + X_batch[row * 64 + col] = sXt[col * 65 + row]; + } +} + +static __global__ void solve_tri_f32_128x128_opt(const float * __restrict__ A, + const float * __restrict__ B, + float * __restrict__ X, + const uint3 ne02, + const size_t nb02, + const size_t nb03, + const size_t nb12, + const size_t nb13, + const size_t nb2, + const size_t nb3, + const int n, + const int k) { + const int batch_idx = blockIdx.x; + const int lane = threadIdx.x; + const int warp_id = threadIdx.y; + + const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02); + const int64_t i02 = i02_i03.y; + const int64_t i03 = i02_i03.x; + + const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03); + const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13); + float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3); + + // Shared memory with padding to avoid bank conflicts + // Layout: sA[128][128] + sXt[128][129] + extern __shared__ char smem_raw[]; + float * sA = (float *)smem_raw; // 128×128 (zero-initialized for unused parts) + float * sXt = sA + 128 * 128; // 128×129 (padded) + + const int tid = lane + warp_id * WARP_SIZE; + + // Zero-initialize shared memory first (important for variable n, k) + #pragma unroll 16 + for (int i = tid; i < 128 * 128; i += 1024) { + sA[i] = 0.0f; + } + #pragma unroll 16 + for (int i = tid; i < 128 * 129; i += 1024) { + sXt[i] = 0.0f; + } + __syncthreads(); + + // Cooperative load of A matrix (n×n elements) + for (int i = tid; i < n * n; i += 1024) { + const int row = i / n; + const int col = i % n; + sA[row * 128 + col] = A_batch[row * n + col]; + } + + // Cooperative load of B matrix transposed into sXt + // sXt[col * 129 + row] = B[row * k + col] + for (int i = tid; i < n * k; i += 1024) { + const int row = i / k; + const int col = i % k; + sXt[col * 129 + row] = B_batch[row * k + col]; + } + + __syncthreads(); + + // Each warp handles columns: col_base to col_base+3 + // But only process if col < k + const int col_base = warp_id * 4; + + // Forward substitution with all columns processed in parallel + for (int row = 0; row < n; ++row) { + float sum0 = 0.0f, sum1 = 0.0f, sum2 = 0.0f, sum3 = 0.0f; + + // Sum over j < row - each lane handles multiple elements + for (int j = lane; j < row; j += WARP_SIZE) { + const float a_val = sA[row * 128 + j]; + if (col_base + 0 < k) sum0 += a_val * sXt[(col_base + 0) * 129 + j]; + if (col_base + 1 < k) sum1 += a_val * sXt[(col_base + 1) * 129 + j]; + if (col_base + 2 < k) sum2 += a_val * sXt[(col_base + 2) * 129 + j]; + if (col_base + 3 < k) sum3 += a_val * sXt[(col_base + 3) * 129 + j]; + } + + // Warp-level reduction + sum0 = warp_reduce_sum(sum0); + sum1 = warp_reduce_sum(sum1); + sum2 = warp_reduce_sum(sum2); + sum3 = warp_reduce_sum(sum3); + + // Lane 0 computes and stores the result + if (lane == 0) { + const float inv_diag = 1.0f / sA[row * 128 + row]; + if (col_base + 0 < k) { + sXt[(col_base + 0) * 129 + row] = (sXt[(col_base + 0) * 129 + row] - sum0) * 
inv_diag; + } + if (col_base + 1 < k) { + sXt[(col_base + 1) * 129 + row] = (sXt[(col_base + 1) * 129 + row] - sum1) * inv_diag; + } + if (col_base + 2 < k) { + sXt[(col_base + 2) * 129 + row] = (sXt[(col_base + 2) * 129 + row] - sum2) * inv_diag; + } + if (col_base + 3 < k) { + sXt[(col_base + 3) * 129 + row] = (sXt[(col_base + 3) * 129 + row] - sum3) * inv_diag; + } + } + + __syncwarp(); + } + + __syncthreads(); + + // Cooperative write of results back (transpose sXt to X) + for (int i = tid; i < n * k; i += 1024) { + const int row = i / k; + const int col = i % k; + X_batch[row * k + col] = sXt[col * 129 + row]; + } +} + +static __global__ void solve_tri_f32_256x256_tiled(const float * __restrict__ A, + const float * __restrict__ B, + float * __restrict__ X, + const uint3 ne02, + const size_t nb02, + const size_t nb03, + const size_t nb12, + const size_t nb13, + const size_t nb2, + const size_t nb3, + const int n, + const int k) { + const int batch_idx = blockIdx.x; + const int lane = threadIdx.x; + const int warp_id = threadIdx.y; + + const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02); + const int64_t i02 = i02_i03.y; + const int64_t i03 = i02_i03.x; + + const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03); + const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13); + float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3); + + // Tiled approach using 64×64 tiles to fit in shared memory + constexpr int TILE_SIZE = 64; + + extern __shared__ char smem_raw[]; + float * sA_tile = (float *)smem_raw; // 64×64 = 16KB + float * sXt_tile = sA_tile + TILE_SIZE * TILE_SIZE; // 64×65 = 16.25KB (padded) + float * sA_off = sXt_tile + TILE_SIZE * (TILE_SIZE+1); // 64×64 = 16KB (for off-diagonal blocks) + + const int tid = lane + warp_id * WARP_SIZE; + + // Initialize X = B (we'll solve in-place conceptually, using global memory) + for (int i = tid; i < n * k; i += 1024) { + X_batch[i] = B_batch[i]; + } + __syncthreads(); + + // Process tile-by-tile along the diagonal + for (int tile_row = 0; tile_row < n; tile_row += TILE_SIZE) { + const int tile_n = min(TILE_SIZE, n - tile_row); // Actual rows in this tile + + // Zero-init and load diagonal tile of A + for (int i = tid; i < TILE_SIZE * TILE_SIZE; i += 1024) { + sA_tile[i] = 0.0f; + } + __syncthreads(); + + for (int i = tid; i < tile_n * tile_n; i += 1024) { + int local_row = i / tile_n; + int local_col = i % tile_n; + sA_tile[local_row * TILE_SIZE + local_col] = A_batch[(tile_row + local_row) * n + tile_row + local_col]; + } + __syncthreads(); + + // For each column tile of X + for (int tile_col = 0; tile_col < k; tile_col += TILE_SIZE) { + const int tile_k = min(TILE_SIZE, k - tile_col); // Actual columns in this tile + + // Zero-init and load X tile transposed + for (int i = tid; i < TILE_SIZE * (TILE_SIZE+1); i += 1024) { + sXt_tile[i] = 0.0f; + } + __syncthreads(); + + for (int i = tid; i < tile_n * tile_k; i += 1024) { + int local_row = i / tile_k; + int local_col = i % tile_k; + sXt_tile[local_col * (TILE_SIZE+1) + local_row] = + X_batch[(tile_row + local_row) * k + tile_col + local_col]; + } + __syncthreads(); + + // Apply updates from previous tile rows + for (int prev_tile = 0; prev_tile < tile_row; prev_tile += TILE_SIZE) { + const int prev_n = min(TILE_SIZE, n - prev_tile); + + // Zero-init and load off-diagonal block + for (int i = tid; i < TILE_SIZE * TILE_SIZE; i += 1024) { + sA_off[i] = 0.0f; + } + __syncthreads(); + + for (int i = tid; i < tile_n * prev_n; i += 1024) { + int local_row = i 
/ prev_n; + int local_col = i % prev_n; + sA_off[local_row * TILE_SIZE + local_col] = A_batch[(tile_row + local_row) * n + prev_tile + local_col]; + } + __syncthreads(); + + // Update: X_tile -= A_off @ X_prev + int col0 = warp_id * 2; + int col1 = warp_id * 2 + 1; + + for (int row = 0; row < tile_n; row++) { + float sum0 = 0.0f, sum1 = 0.0f; + + for (int j = lane; j < prev_n; j += WARP_SIZE) { + float a_val = sA_off[row * TILE_SIZE + j]; + if (col0 < tile_k) { + float x_prev0 = X_batch[(prev_tile + j) * k + tile_col + col0]; + sum0 += a_val * x_prev0; + } + if (col1 < tile_k) { + float x_prev1 = X_batch[(prev_tile + j) * k + tile_col + col1]; + sum1 += a_val * x_prev1; + } + } + + sum0 = warp_reduce_sum(sum0); + sum1 = warp_reduce_sum(sum1); + + if (lane == 0) { + if (col0 < tile_k) { + sXt_tile[col0 * (TILE_SIZE+1) + row] -= sum0; + } + if (col1 < tile_k) { + sXt_tile[col1 * (TILE_SIZE+1) + row] -= sum1; + } + } + __syncwarp(); + } + __syncthreads(); + } + + // Solve the diagonal tile + int col0 = warp_id * 2; + int col1 = warp_id * 2 + 1; + + for (int row = 0; row < tile_n; ++row) { + float sum0 = 0.0f, sum1 = 0.0f; + + if (lane < row) { + float a_val = sA_tile[row * TILE_SIZE + lane]; + if (col0 < tile_k) sum0 = a_val * sXt_tile[col0 * (TILE_SIZE+1) + lane]; + if (col1 < tile_k) sum1 = a_val * sXt_tile[col1 * (TILE_SIZE+1) + lane]; + } + if (row > WARP_SIZE) { + int j2 = lane + WARP_SIZE; + if (j2 < row) { + float a_val2 = sA_tile[row * TILE_SIZE + j2]; + if (col0 < tile_k) sum0 += a_val2 * sXt_tile[col0 * (TILE_SIZE+1) + j2]; + if (col1 < tile_k) sum1 += a_val2 * sXt_tile[col1 * (TILE_SIZE+1) + j2]; + } + } + + sum0 = warp_reduce_sum(sum0); + sum1 = warp_reduce_sum(sum1); + + if (lane == 0) { + float inv_diag = 1.0f / sA_tile[row * TILE_SIZE + row]; + if (col0 < tile_k) { + sXt_tile[col0 * (TILE_SIZE+1) + row] = + (sXt_tile[col0 * (TILE_SIZE+1) + row] - sum0) * inv_diag; + } + if (col1 < tile_k) { + sXt_tile[col1 * (TILE_SIZE+1) + row] = + (sXt_tile[col1 * (TILE_SIZE+1) + row] - sum1) * inv_diag; + } + } + __syncwarp(); + } + __syncthreads(); + + // Write solved tile back to global memory + for (int i = tid; i < tile_n * tile_k; i += 1024) { + int local_row = i / tile_k; + int local_col = i % tile_k; + X_batch[(tile_row + local_row) * k + tile_col + local_col] = + sXt_tile[local_col * (TILE_SIZE+1) + local_row]; + } + __syncthreads(); + } + } } -// ====================== -// Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction -// ====================== // When ncols_template == 0 the bounds for the loops in this function are not // known and can't be unrolled. As we want to keep pragma unroll for all other // cases we supress the clang transformation warning here. @@ -88,7 +535,9 @@ static void solve_tri_f32_cublas(ggml_backend_cuda_context & ctx, # pragma clang diagnostic push # pragma clang diagnostic ignored "-Wpass-failed" #endif // __clang__ -template +// Template parameters: n_template/k_template are the matrix dimensions when known at compile time (0 = runtime) +// threads_y_template is the number of threads in y dimension (max 32 to stay within 1024 thread limit) +template static __global__ void solve_tri_f32_fast(const float * __restrict__ A, const float * __restrict__ B, float * __restrict__ X, @@ -103,14 +552,10 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A, const int k_arg) { const int n = n_template == 0 ? n_arg : n_template; const int k = k_template == 0 ? k_arg : k_template; + const int threads_y = threads_y_template == 0 ? 
blockDim.y : threads_y_template; const int batch_idx = blockIdx.x; const int lane = threadIdx.x; - const int col_idx = threadIdx.y; - - if (col_idx >= k) { - return; - } const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02); const int64_t i02 = i02_i03.y; @@ -121,58 +566,94 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A, float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3); __shared__ float sA[MAX_N_FAST * MAX_N_FAST]; + __shared__ float sXt[MAX_N_FAST * (MAX_K_FAST + 1)]; const int offset = threadIdx.x + threadIdx.y * blockDim.x; + const int block_threads = blockDim.x * blockDim.y; + // Load A matrix into shared memory #pragma unroll - for (int i = 0; i < n * n; i += k * WARP_SIZE) { - const int i0 = i + offset; + for (int i = 0; i < n * n; i += block_threads) { + int i0 = i + offset; if (i0 < n * n) { sA[i0] = A_batch[i0]; } } + const int rows_per_warp = (n + WARP_SIZE - 1) / WARP_SIZE; + const int cols_per_thread = (k + threads_y - 1) / threads_y; + + // Load B matrix into shared memory (transposed as sXt) + // Each thread handles multiple columns when k > threads_y + for (int c = 0; c < cols_per_thread; c++) { + const int col_idx = threadIdx.y + c * threads_y; + if (col_idx < k) { +#pragma unroll + for (int i = 0; i < rows_per_warp; i++) { + const int i0 = lane + i * WARP_SIZE; + if (i0 < n) { + sXt[col_idx * n + i0] = B_batch[i0 * k + col_idx]; + } + } + } + } + __syncthreads(); - float x_low = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f; - float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f; - - const int half = WARP_SIZE; - const int nrows_low = (n < half) ? n : half; + // Solve for each column this thread handles + for (int c = 0; c < cols_per_thread; c++) { + const int col_idx = threadIdx.y + c * threads_y; + if (col_idx >= k) { + continue; + } #pragma unroll - for (int row = 0; row < nrows_low; ++row) { - float sum = 0.0f; - if (lane < row) { - sum += sA[row * n + lane] * x_low; - } - sum = warp_reduce_sum(sum); + for (int row = 0; row < n; ++row) { + float sum = 0.0f; - if (lane == row) { - x_low = (x_low - sum) / sA[row * n + row]; + { + int j = lane; + if (j < row) { + sum += sA[row * n + j] * sXt[col_idx * n + j]; + } + } + if (row >= WARP_SIZE) { + int j = WARP_SIZE + lane; + if (j < row) { + sum += sA[row * n + j] * sXt[col_idx * n + j]; + } + } + + sum = warp_reduce_sum(sum); + + if (lane == 0) { + const float b_val = sXt[col_idx * n + row]; + const float a_diag = sA[row * n + row]; + // no safeguards for division by zero because that indicates corrupt + // data anyway + sXt[col_idx * n + row] = (b_val - sum) / a_diag; + } + } + + // Sync between columns to ensure writes are visible + if (c + 1 < cols_per_thread) { + __syncwarp(); } } -#pragma unroll - for (int row = half; row < n; ++row) { - float sum = sA[row * n + lane] * x_low; - const int j = half + lane; - if (j < row) { - sum += sA[row * n + j] * x_high; - } - sum = warp_reduce_sum(sum); - - if (lane == row - half) { - x_high = (x_high - sum) / sA[row * n + row]; - } - } + __syncthreads(); + // Write results back + for (int c = 0; c < cols_per_thread; c++) { + const int col_idx = threadIdx.y + c * threads_y; + if (col_idx < k) { #pragma unroll - for (int rr = 0; rr < 2; ++rr) { - const int row = rr * WARP_SIZE + lane; - if (row < n) { - const float val = (row < half) ? 
x_low : x_high; - X_batch[row * k + col_idx] = val; + for (int i = 0; i < rows_per_warp; i++) { + const int i0 = lane + i * WARP_SIZE; + if (i0 < n) { + X_batch[i0 * k + col_idx] = sXt[col_idx * n + i0]; + } + } } } } @@ -180,6 +661,76 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A, # pragma clang diagnostic pop #endif // __clang__ +// cuBLAS batched TRSM fallback for larger matrices or as robust path +// Solves A * X = B where A is lower triangular +// This function modifies X in-place (X should be initialized with B) +static void solve_tri_f32_cublas( + ggml_backend_cuda_context & ctx, + const float * A, + float * X, // Input: B, Output: solution X (in-place) + int n, + int k, + int64_t ne02, + int64_t ne03, + size_t nb02, + size_t nb03, + size_t nb2, + size_t nb3, + cudaStream_t stream +) { + const int64_t total_batches = ne02 * ne03; + + // Allocate pointer arrays on device + ggml_cuda_pool_alloc A_ptrs(ctx.pool(), total_batches); + ggml_cuda_pool_alloc X_ptrs(ctx.pool(), total_batches); + + // Set up pointer arrays on device (CUDA graph compatible) + { + const int block_size = 256; + const int grid_size = (total_batches + block_size - 1) / block_size; + setup_trsm_batch_pointers<<>>( + A, X, + A_ptrs.get(), X_ptrs.get(), + ne02, total_batches, + nb02, nb03, nb2, nb3 + ); + CUDA_CHECK(cudaGetLastError()); + } + + // Get cuBLAS handle and set stream + cublasHandle_t handle = ctx.cublas_handle(); + cublasSetStream(handle, stream); + + // Save current math mode and set to default for accuracy + // (TF32 can cause numerical issues with triangular solves) + cublasMath_t prev_math_mode; + cublasGetMathMode(handle, &prev_math_mode); + cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH); + + const float alpha = 1.0f; + + cublasStatus_t status = cublasStrsmBatched( + handle, + CUBLAS_SIDE_RIGHT, // A is on the right: X * A = B + CUBLAS_FILL_MODE_UPPER, // A^T is upper (since A is lower in row-major) + CUBLAS_OP_N, // No additional transpose + CUBLAS_DIAG_NON_UNIT, // Diagonal is not assumed to be 1 + k, // m: rows of X^T (columns of X) + n, // n: columns of X^T (rows of X) = size of A + &alpha, + (const float **)A_ptrs.get(), n, // lda = n (leading dimension) + (float **)X_ptrs.get(), k, // ldb = k (leading dimension of X^T) + total_batches + ); + + // Restore previous math mode + cublasSetMathMode(handle, prev_math_mode); + + if (status != CUBLAS_STATUS_SUCCESS) { + GGML_LOG_ERROR("cuBLAS batched TRSM failed: %d\n", (int)status); + } +} + static void solve_tri_f32_cuda(const float * A, const float * B, float * X, @@ -195,81 +746,133 @@ static void solve_tri_f32_cuda(const float * A, size_t nb3, cudaStream_t stream) { const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02); - dim3 threads(WARP_SIZE, k); - dim3 grid(ne02 * ne03); + dim3 grid(ne02 * ne03); + + // Handle large matrices first (256×256 and 65-128 range) + + // Route sizes 65-256 to the tiled kernel + if (n > 64 || k > 64) { + // Use the tiled kernel which works for any size up to 256 + // and only requires ~48KB shared memory (within standard limits) + dim3 threads_256(WARP_SIZE, 32); // 1024 threads + // Shared memory: 64×64 + 64×65 + 64×64 = 16KB + 16.25KB + 16KB = ~48KB + const size_t smem_size = (64 * 64 + 64 * 65 + 64 * 64) * sizeof(float); + + // Configure extended shared memory for this kernel + static bool smem_configured_tiled = false; + if (!smem_configured_tiled) { + cudaFuncSetAttribute(solve_tri_f32_256x256_tiled, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + smem_configured_tiled 
= true; + } + + solve_tri_f32_256x256_tiled<<>>( + A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k); + return; + } + + // Limit threads_y to 32 to ensure we don't exceed 1024 threads per block (32 * 32 = 1024) + const int threads_y = k <= 32 ? k : 32; + dim3 threads(WARP_SIZE, threads_y); + if (n == 64) { switch (k) { + case 64: + { + // Use optimized kernel for n=64, k=64 case (common in Qwen3 Next DeltaNet) + // Block config: 32x32 = 1024 threads (32 warps) + dim3 threads_64x64(WARP_SIZE, 32); + solve_tri_f32_64x64_opt + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3); + } + break; + case 48: + // k=48 needs 2 columns per thread (threads_y=32, some threads handle 1, some 2) + solve_tri_f32_fast<64, 48, 32> + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 40: + // k=40 needs 2 columns per thread (threads_y=32, some threads handle 1, some 2) + solve_tri_f32_fast<64, 40, 32> + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; case 32: - solve_tri_f32_fast<64, 32> + solve_tri_f32_fast<64, 32, 32> <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); break; case 16: - solve_tri_f32_fast<64, 16> + solve_tri_f32_fast<64, 16, 16> <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); break; case 14: - solve_tri_f32_fast<64, 14> + solve_tri_f32_fast<64, 14, 14> <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); break; case 12: - solve_tri_f32_fast<64, 12> + solve_tri_f32_fast<64, 12, 12> <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); break; case 10: - solve_tri_f32_fast<64, 10> + solve_tri_f32_fast<64, 10, 10> <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); break; case 8: - solve_tri_f32_fast<64, 8> + solve_tri_f32_fast<64, 8, 8> <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); break; case 6: - solve_tri_f32_fast<64, 6> + solve_tri_f32_fast<64, 6, 6> <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); break; case 4: - solve_tri_f32_fast<64, 4> + solve_tri_f32_fast<64, 4, 4> <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); break; case 2: - solve_tri_f32_fast<64, 2> + solve_tri_f32_fast<64, 2, 2> <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); break; case 1: - solve_tri_f32_fast<64, 1> + solve_tri_f32_fast<64, 1, 1> <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); break; default: - solve_tri_f32_fast<0, 0> + solve_tri_f32_fast<0, 0, 0> <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k); } } else { // run general case - solve_tri_f32_fast<0, 0> + solve_tri_f32_fast<0, 0, 0> <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k); } } void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; // A (n×n, lower triangular) - const ggml_tensor * src1 = dst->src[1]; // B (n×k) + const ggml_tensor * src0 = dst->src[0]; // A (triangular n x n matrix) + const ggml_tensor * src1 = dst->src[1]; // B (right hand side of n x k equation columns) - ggml_is_contiguous(src0); - ggml_is_contiguous(src1); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); - const int64_t n = src0->ne[0]; - const int64_t k = src1->ne[0]; - const int64_t ne02 = src0->ne[2]; - const int64_t ne03 = src0->ne[3]; + const int64_t n = src0->ne[0]; + const int64_t k = src1->ne[0]; - if (n <= MAX_N_FAST && k <= MAX_K_FAST) { - solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, 
n, k, - src0->ne[2], src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float), - src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float), - dst->nb[3] / sizeof(float), ctx.stream()); - } else { - solve_tri_f32_cublas(ctx, (const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k, - ne02, ne03, src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float), - src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float), - dst->nb[3] / sizeof(float), ctx.stream()); - } + const int64_t total_batches = src0->ne[2] * src0->ne[3]; + const size_t X_size = n * k * total_batches * sizeof(float); + + // Copy B to X (cuBLAS solves in-place) + CUDA_CHECK(cudaMemcpyAsync( + dst->data, src1->data, X_size, + cudaMemcpyDeviceToDevice, ctx.stream() + )); + + solve_tri_f32_cublas( + ctx, + (const float *) src0->data, + (float *) dst->data, + n, k, + src0->ne[2], src0->ne[3], + src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float), + dst->nb[2] / sizeof(float), dst->nb[3] / sizeof(float), + ctx.stream() + ); } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index eb3ae72eaa..705261acbe 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1028,6 +1028,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GATED_LINEAR_ATTN", "RWKV_WKV7", "SOLVE_TRI", + "DELTA_NET", "UNARY", @@ -1045,7 +1046,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GLU", }; -static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95"); +static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1137,6 +1138,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "gated_linear_attn(k, v, q, gate, s)", "rwkv_wkv7(r, w, k, v, a, b, s)", "A X = B, A triangular, solve X", + "delta_net(q, k, v, g, beta, state)", "unary(x)", @@ -1154,7 +1156,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "glu(x)", }; -static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95"); +static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -6093,6 +6095,63 @@ struct ggml_tensor * ggml_solve_tri( return result; } +// delta_net + +struct ggml_tensor * ggml_delta_net( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * g, + struct ggml_tensor * beta, + struct ggml_tensor * state) { + GGML_ASSERT(ggml_is_contiguous(q)); + GGML_ASSERT(ggml_is_contiguous(k)); + GGML_ASSERT(ggml_is_contiguous(v)); + GGML_ASSERT(ggml_is_contiguous(g)); + GGML_ASSERT(ggml_is_contiguous(beta)); + GGML_ASSERT(ggml_is_contiguous(state)); + + GGML_ASSERT(q->type == GGML_TYPE_F32); + GGML_ASSERT(k->type == GGML_TYPE_F32); + GGML_ASSERT(v->type == GGML_TYPE_F32); + GGML_ASSERT(g->type == GGML_TYPE_F32); + GGML_ASSERT(beta->type == GGML_TYPE_F32); + GGML_ASSERT(state->type == GGML_TYPE_F32); + + const int64_t S_k = q->ne[0]; + const int64_t n_tokens = q->ne[1]; + const int64_t H_k = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[2]; + + GGML_UNUSED(S_k); + + GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == n_tokens && k->ne[2] == H_k && k->ne[3] == n_seqs); + GGML_ASSERT(v->ne[1] == n_tokens && v->ne[3] == n_seqs); + GGML_ASSERT(g->ne[0] == n_tokens && g->ne[2] == H_k && g->ne[3] == n_seqs); + GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[3] == n_seqs); + GGML_ASSERT(state->ne[0] == S_v && 
state->ne[1] == S_v * H_v && state->ne[3] == n_seqs); + GGML_ASSERT(H_k == H_v); + + const int64_t output_size = S_v * H_v * n_tokens * n_seqs; + const int64_t state_size = S_v * S_v * H_v * n_seqs; + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, output_size + state_size); + + result->op = GGML_OP_DELTA_NET; + result->src[0] = q; + result->src[1] = k; + result->src[2] = v; + result->src[3] = g; + result->src[4] = beta; + result->src[5] = state; + + return result; +} + //////////////////////////////////////////////////////////////////////////////// struct ggml_hash_set ggml_hash_set_new(size_t size) { diff --git a/src/models/models.h b/src/models/models.h index ffb36acc61..82d0cbe890 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -460,14 +460,14 @@ private: ggml_tensor * diag_mask, int il); - ggml_tensor * build_delta_net_autoregressive( + ggml_tensor * build_delta_net_fused( ggml_tensor * q, ggml_tensor * k, ggml_tensor * v, ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state, - int il); + int il); ggml_tensor * build_norm_gated( ggml_tensor * input, diff --git a/src/models/qwen3next.cpp b/src/models/qwen3next.cpp index 775b3135d3..d441a204b5 100644 --- a/src/models/qwen3next.cpp +++ b/src/models/qwen3next.cpp @@ -426,6 +426,78 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_autoregressive( return ggml_concat(ctx0, flat_output, flat_state, 0); } +ggml_tensor * llm_build_qwen3next::build_delta_net_fused( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state, + int il) { + GGML_ASSERT(ggml_is_contiguous(q)); + GGML_ASSERT(ggml_is_contiguous(k)); + GGML_ASSERT(ggml_is_contiguous(v)); + GGML_ASSERT(ggml_is_contiguous(g)); + GGML_ASSERT(ggml_is_contiguous(beta)); + GGML_ASSERT(ggml_is_contiguous(state)); + + const int64_t S_k = q->ne[0]; + const int64_t H_k = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + + GGML_ASSERT(v->ne[2] == n_tokens); + GGML_ASSERT(k->ne[2] == n_tokens); + GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); + GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); + GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); + + GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); + + GGML_ASSERT(H_k == H_v); + + q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs); + k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs); + v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); + g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 1, 3, 0, 2), n_tokens, 1, H_k, n_seqs); + beta = ggml_cont_4d(ctx0, ggml_permute(ctx0, beta, 1, 2, 0, 3), 1, n_tokens, H_k, n_seqs); + + cb(q, "q_fused", il); + cb(k, "k_fused", il); + cb(v, "v_fused", il); + cb(g, "g_fused", il); + cb(beta, "beta_fused", il); + + ggml_tensor * fused_result = ggml_delta_net(ctx0, q, k, v, g, beta, state); + cb(fused_result, "delta_net_fused_raw", il); + + const int64_t output_size = S_v * H_v * n_tokens * n_seqs; + const int64_t state_size = S_v * S_v * H_v * n_seqs; + + ggml_tensor * output_4d = ggml_view_4d(ctx0, fused_result, + S_v, H_v, n_tokens, n_seqs, + S_v * 
ggml_element_size(fused_result), + S_v * H_v * ggml_element_size(fused_result), + S_v * H_v * n_tokens * ggml_element_size(fused_result), + 0); + cb(output_4d, "fused_output_4d", il); + + ggml_tensor * flat_output = ggml_cont_1d(ctx0, output_4d, output_size); + cb(flat_output, "fused_flat_output", il); + + ggml_tensor * flat_state = ggml_view_1d(ctx0, fused_result, state_size, + output_size * ggml_element_size(fused_result)); + cb(flat_state, "fused_flat_state", il); + + ggml_tensor * result = ggml_concat(ctx0, flat_output, flat_state, 0); + + return result; +} + ggml_tensor * llm_build_qwen3next::build_norm_gated( ggml_tensor * input, ggml_tensor * weights, @@ -737,13 +809,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - // Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens - ggml_tensor * attn_out; - if (n_seq_tokens == 1) { - attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il); - } else { - attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il); - } + ggml_tensor * attn_out = build_delta_net_fused(q_conv, k_conv, v_conv, gate, beta, state, il); cb(attn_out, "attn_out", il); // The tensors were concatenated 1d, so we need to extract them 1d as well @@ -844,7 +910,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int cur = moe_out; } } else { - // Dense FFN branch (not currently used I believe) + // Dense FFN branch cur = build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, From 0a192937a14c28b17ee2cc81bb4875be40053d48 Mon Sep 17 00:00:00 2001 From: hauhaut Date: Tue, 16 Dec 2025 02:08:50 +0100 Subject: [PATCH 2/3] qwen3next: trim comments --- ggml/src/ggml-cuda/delta-net.cu | 6 ++---- src/models/qwen3next.cpp | 35 --------------------------------- 2 files changed, 2 insertions(+), 39 deletions(-) diff --git a/ggml/src/ggml-cuda/delta-net.cu b/ggml/src/ggml-cuda/delta-net.cu index 634780d5dc..acac648c75 100644 --- a/ggml/src/ggml-cuda/delta-net.cu +++ b/ggml/src/ggml-cuda/delta-net.cu @@ -1418,17 +1418,15 @@ static void delta_net_f32_cuda( // - Vectors (Q,K,V,KBeta,VBeta,KCumdecay,VPrime,VNew,Out): 9 × HEAD_DIM × sizeof(float) = 4608 bytes // - Warp scratch: 16 × sizeof(float) = 64 bytes // Total: 65536 + 4608 + 64 = 70208 bytes (~68.6KB) - // Note: __shared__ scalars (decay, beta, etc.) are static, not dynamic + // __shared__ scalars (decay, beta, etc.) 
are static, not dynamic constexpr size_t state_bytes = 128 * 128 * sizeof(float); // 64KB constexpr size_t vector_bytes = 9 * 128 * sizeof(float); // 4.5KB constexpr size_t warp_scratch_bytes = 16 * sizeof(float); // 64B constexpr size_t blackwell_smem_size = state_bytes + vector_bytes + warp_scratch_bytes; - // Sanity check: ensure we allocated enough static_assert(blackwell_smem_size == 70208, "Shared memory size mismatch"); - // Check for A/B comparison mode - // Use a function-local static for thread-safe lazy initialization + // A/B comparison mode (set GGML_CUDA_DELTA_NET_AB=1) static const bool ab_mode = []() { const char* env = std::getenv("GGML_CUDA_DELTA_NET_AB"); if (env != nullptr) { diff --git a/src/models/qwen3next.cpp b/src/models/qwen3next.cpp index d441a204b5..af43256c6d 100644 --- a/src/models/qwen3next.cpp +++ b/src/models/qwen3next.cpp @@ -33,12 +33,9 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); cb(cur, "attn_norm", il); - // Determine layer type and build appropriate attention mechanism if (hparams.is_recurrent(il)) { - // Linear attention layer (gated delta net) cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, diag_mask, il); } else { - // Full attention layer cur = build_layer_attn(inp->get_attn(), cur, inp_pos, il); } @@ -47,37 +44,28 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } - // Residual connection cur = ggml_add(ctx0, cur, inpSA); cb(cur, "attn_residual", il); - // Save the tensor before post-attention norm for residual connection ggml_tensor * ffn_residual = cur; - // Post-attention norm ggml_tensor * attn_post_norm = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il); cb(attn_post_norm, "attn_post_norm", il); - // FFN layer (MoE or dense) - without residual connection cur = build_layer_ffn(attn_post_norm, il); cb(cur, "ffn_out", il); - // Residual connection for FFN - add to the tensor from before post_attention_layernorm cur = ggml_add(ctx0, cur, ffn_residual); cb(cur, "post_moe", il); - // Input for next layer inpL = cur; } cur = inpL; - // Final norm cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); - cb(cur, "result_norm", -1); res->t_embd = cur; - // LM head cur = build_lora_mm(model.output, cur); cb(cur, "result_output", -1); @@ -517,16 +505,11 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn( const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention - - // Qwen3Next uses a single Q projection that outputs query + gate ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur); cb(Qcur_full, "Qcur_full", il); Qcur_full = ggml_reshape_4d(ctx0, Qcur_full, n_embd_head * 2, n_head, n_tokens, 1); - // Split Q projection into query and gate - // The split should be along dimension 0 (the feature dimension) ggml_tensor * Qcur = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1, Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0); ggml_tensor * gate = @@ -535,11 +518,9 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn( cb(Qcur, "Qcur", il); cb(gate, "gate", il); - // Now reshape Qcur to [n_embd_head, n_head, n_tokens] for multi-head attention Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); cb(Qcur, 
"Qcur_reshaped", il); - // Apply Q normalization Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -549,18 +530,15 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn( ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); cb(Vcur, "Vcur", il); - // Apply K normalization Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); cb(Kcur, "Kcur_normed", il); - // Reshape gate to [n_embd, n_tokens] for the sigmoid gating (flatten the heads) gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); cb(gate, "gate_reshaped", il); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - // Apply RoPE Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -575,7 +553,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn( cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - // Attention computation const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; cur = build_attn(inp, @@ -861,9 +838,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( } ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int il) { - // Check if this is an MoE layer if (model.layers[il].ffn_gate_inp != nullptr) { - // MoE branch ggml_tensor * moe_out = build_moe_ffn(cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, @@ -873,7 +848,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(moe_out, "ffn_moe_out", il); - // Add shared experts if present - following Qwen3Next reference implementation if (model.layers[il].ffn_up_shexp != nullptr) { ggml_tensor * ffn_shexp = build_ffn(cur, @@ -884,23 +858,15 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int LLM_FFN_SILU, LLM_FFN_PAR, il); cb(ffn_shexp, "ffn_shexp", il); - // Apply shared expert gating as in the reference implementation - // The shared expert has its own gate that is sigmoided - // Note: ffn_gate_inp_shexp is the shared expert gate (outputs 1 value per token) ggml_tensor * shared_gate = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur); cb(shared_gate, "shared_expert_gate", il); - // Apply sigmoid to the gate shared_gate = ggml_sigmoid(ctx0, shared_gate); cb(shared_gate, "shared_expert_gate_sigmoid", il); - // The gate needs to be broadcast to match the dimensions of ffn_shexp - // ffn_shexp is [n_embd, n_tokens, 1, 1] and shared_gate is [1, n_tokens, 1, 1] - // We need to repeat the gate along the feature dimension shared_gate = ggml_repeat(ctx0, shared_gate, ffn_shexp); cb(shared_gate, "shared_expert_gate_broadcast", il); - // Apply the gate to the shared expert output ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate); cb(ffn_shexp, "ffn_shexp_gated", il); @@ -910,7 +876,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int cur = moe_out; } } else { - // Dense FFN branch cur = build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, From 128a6c28316155786ebbf41b58bd54cda62d519a Mon Sep 17 00:00:00 2001 From: hauhaut Date: Tue, 16 Dec 2025 18:21:29 +0100 Subject: [PATCH 3/3] ggml-cpu: add DELTA_NET backend + tests --- ggml/src/ggml-cpu/ggml-cpu.c | 5 ++ ggml/src/ggml-cpu/ops.cpp | 133 +++++++++++++++++++++++++++++++++++ 
ggml/src/ggml-cpu/ops.h | 1 + src/models/models.h | 9 +++ tests/test-backend-ops.cpp | 33 +++++++++ 5 files changed, 181 insertions(+) diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index a59b518938..02a70c9b47 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2014,6 +2014,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_rwkv_wkv7(params, tensor); } break; + case GGML_OP_DELTA_NET: + { + ggml_compute_forward_delta_net(params, tensor); + } break; case GGML_OP_SOLVE_TRI: { ggml_compute_forward_solve_tri(params, tensor); @@ -2339,6 +2343,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_RWKV_WKV6: case GGML_OP_GATED_LINEAR_ATTN: case GGML_OP_RWKV_WKV7: + case GGML_OP_DELTA_NET: { n_tasks = n_threads; } break; diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 3032783971..c22c362ea7 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -10091,6 +10091,139 @@ void ggml_compute_forward_rwkv_wkv7( } } +// ggml_compute_forward_delta_net + +static void ggml_compute_forward_delta_net_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * src2 = dst->src[2]; + const ggml_tensor * src3 = dst->src[3]; + const ggml_tensor * src4 = dst->src[4]; + const ggml_tensor * src5 = dst->src[5]; + + const int64_t head_dim = src0->ne[0]; + const int64_t n_tokens = src0->ne[1]; + const int64_t n_heads = src0->ne[2]; + const int64_t n_seqs = src0->ne[3]; + + const int64_t output_size = head_dim * n_tokens * n_heads * n_seqs; + + const float * q_data = (const float *) src0->data; + const float * k_data = (const float *) src1->data; + const float * v_data = (const float *) src2->data; + const float * g_data = (const float *) src3->data; + const float * beta_data = (const float *) src4->data; + const float * state_in = (const float *) src5->data; + float * out_data = (float *) dst->data; + float * state_out = out_data + output_size; + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t total_heads = n_heads * n_seqs; + const int64_t heads_per_thread = (total_heads + nth - 1) / nth; + const int64_t h_start = ith * heads_per_thread; + const int64_t h_end = (h_start + heads_per_thread < total_heads) ? 
h_start + heads_per_thread : total_heads; + + const float eps = 1e-12f; + const float scale = 1.0f / sqrtf((float)head_dim); + + float * v_new_buf = (float *)malloc(head_dim * sizeof(float)); + if (!v_new_buf) { + return; + } + + for (int64_t h_idx = h_start; h_idx < h_end; h_idx++) { + const int64_t batch_idx = h_idx / n_heads; + const int64_t head_idx = h_idx % n_heads; + + const int64_t qkv_head_offset = batch_idx * (head_dim * n_tokens * n_heads) + head_idx * (head_dim * n_tokens); + const int64_t qkv_token_stride = head_dim; + const int64_t g_head_offset = batch_idx * (n_tokens * n_heads) + head_idx * n_tokens; + const int64_t state_head_offset = batch_idx * (head_dim * head_dim * n_heads) + head_idx * (head_dim * head_dim); + const int64_t out_head_offset = batch_idx * (head_dim * n_heads * n_tokens) + head_idx * head_dim; + const int64_t out_token_stride = head_dim * n_heads; + + for (int64_t i = 0; i < head_dim * head_dim; i++) { + state_out[state_head_offset + i] = state_in[state_head_offset + i]; + } + + float * state = state_out + state_head_offset; + + for (int64_t t = 0; t < n_tokens; t++) { + const float * q_t = q_data + qkv_head_offset + t * qkv_token_stride; + const float * k_t = k_data + qkv_head_offset + t * qkv_token_stride; + const float * v_t = v_data + qkv_head_offset + t * qkv_token_stride; + + float g_val = g_data[g_head_offset + t]; + float beta_raw = beta_data[g_head_offset + t]; + + float q_norm_sq = 0.0f, k_norm_sq = 0.0f; + for (int64_t i = 0; i < head_dim; i++) { + q_norm_sq += q_t[i] * q_t[i]; + k_norm_sq += k_t[i] * k_t[i]; + } + float q_norm_inv = 1.0f / sqrtf(q_norm_sq + eps); + float k_norm_inv = 1.0f / sqrtf(k_norm_sq + eps); + + float beta_val = 1.0f / (1.0f + expf(-beta_raw)); + float decay = expf(fminf(g_val, 50.0f)); + + float attn_score = 0.0f; + for (int64_t i = 0; i < head_dim; i++) { + attn_score += (k_t[i] * k_norm_inv) * (q_t[i] * q_norm_inv * scale); + } + + float * out_t = out_data + out_head_offset + t * out_token_stride; + + for (int64_t row = 0; row < head_dim; row++) { + float v_prime = 0.0f; + float out_val = 0.0f; + + for (int64_t col = 0; col < head_dim; col++) { + float k_col = k_t[col] * k_norm_inv; + float q_col = q_t[col] * q_norm_inv * scale; + float s = state[row + col * head_dim]; + + v_prime += s * k_col * beta_val * decay; + out_val += s * q_col * decay; + } + + float v_new = v_t[row] * beta_val - v_prime; + v_new_buf[row] = v_new; + out_t[row] = out_val + v_new * attn_score; + } + + for (int64_t col = 0; col < head_dim; col++) { + float k_col = k_t[col] * k_norm_inv; + for (int64_t row = 0; row < head_dim; row++) { + float s = state[row + col * head_dim]; + s = decay * s + v_new_buf[row] * k_col; + state[row + col * head_dim] = fminf(fmaxf(s, -1e6f), 1e6f); + } + } + } + } + + free(v_new_buf); +} + +void ggml_compute_forward_delta_net( + const ggml_compute_params * params, + ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + ggml_compute_forward_delta_net_f32(params, dst); + break; + default: + GGML_ABORT("fatal error"); + } +} + // ggml_compute_forward_map_custom1 void ggml_compute_forward_map_custom1( diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index 0fdfee7976..e6fd13eb47 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -102,6 +102,7 @@ void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, s void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst); 
void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_delta_net(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_map_custom3(const struct ggml_compute_params * params, struct ggml_tensor * dst); diff --git a/src/models/models.h b/src/models/models.h index 82d0cbe890..20f1366a62 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -469,6 +469,15 @@ private: ggml_tensor * state, int il); + ggml_tensor * build_delta_net_autoregressive( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state, + int il); + ggml_tensor * build_norm_gated( ggml_tensor * input, ggml_tensor * weights, diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 416218b5b8..d46011496e 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -3550,6 +3550,34 @@ struct test_rwkv_wkv7 : public test_case { } }; +// GGML_OP_DELTA_NET +struct test_delta_net : public test_case { + const ggml_type type; + + const int64_t n_heads; + const int64_t head_dim; + const int64_t n_tokens; + const int64_t n_seqs; + + std::string vars() override { + return VARS_TO_STR5(type, n_heads, head_dim, n_tokens, n_seqs); + } + + test_delta_net(ggml_type type = GGML_TYPE_F32, + int64_t n_heads = 8, int64_t head_dim = 64, int64_t n_tokens = 32, int64_t n_seqs = 2) + : type(type), n_heads(n_heads), head_dim(head_dim), n_tokens(n_tokens), n_seqs(n_seqs) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * q = ggml_new_tensor_4d(ctx, type, head_dim, n_tokens, n_heads, n_seqs); + ggml_tensor * k = ggml_new_tensor_4d(ctx, type, head_dim, n_tokens, n_heads, n_seqs); + ggml_tensor * v = ggml_new_tensor_4d(ctx, type, head_dim, n_tokens, n_heads, n_seqs); + ggml_tensor * g = ggml_new_tensor_4d(ctx, type, n_tokens, 1, n_heads, n_seqs); + ggml_tensor * beta = ggml_new_tensor_4d(ctx, type, 1, n_tokens, n_heads, n_seqs); + ggml_tensor * state = ggml_new_tensor_4d(ctx, type, head_dim, head_dim * n_heads, 1, n_seqs); + return ggml_delta_net(ctx, q, k, v, g, beta, state); + } +}; + // GGML_OP_MUL_MAT struct test_mul_mat : public test_case { const ggml_type type_a; @@ -7322,6 +7350,11 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4)); test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 128, 4)); + test_cases.emplace_back(new test_delta_net(GGML_TYPE_F32, 8, 64, 1, 1)); + test_cases.emplace_back(new test_delta_net(GGML_TYPE_F32, 8, 64, 32, 1)); + test_cases.emplace_back(new test_delta_net(GGML_TYPE_F32, 8, 64, 32, 2)); + test_cases.emplace_back(new test_delta_net(GGML_TYPE_F32, 8, 64, 128, 2)); + #if 0 // > 4GB A matrix. Too slow to be enabled by default. test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 900000, 3, 2592, {1, 1}, {1, 1}));
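// --- Reviewer sketch (annotation only, not part of the patch) -----------------------------------
// Summary of the per-token recurrence that GGML_OP_DELTA_NET computes, distilled from the CPU
// reference (ggml_compute_forward_delta_net_f32) in this series. Per (sequence, head), with
// q_hat/k_hat the L2-normalized q/k, scale = 1/sqrt(head_dim), beta = sigmoid(beta_raw),
// decay = exp(min(g, 50)), and S the [head_dim x head_dim] per-head state stored column-major
// as S[row + col*head_dim]:
//
//   attn  = dot(k_hat, q_hat) * scale
//   v'    = (S * k_hat) * beta * decay
//   v_new = v * beta - v'
//   out   = (S * (q_hat * scale)) * decay + v_new * attn
//   S     = clamp(decay * S + v_new * k_hat^T, -1e6f, 1e6f)
//
// The CUDA recurrent kernel and the Blackwell kernel are expected to reproduce this reference;
// setting the GGML_CUDA_DELTA_NET_AB environment variable runs both and prints the maximum
// output/state deviation per call.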