This commit is contained in:
hauhaut 2025-12-16 21:47:05 -06:00 committed by GitHub
commit 44765b04b3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 2686 additions and 177 deletions

View File

@ -551,6 +551,7 @@ extern "C" {
GGML_OP_GATED_LINEAR_ATTN,
GGML_OP_RWKV_WKV7,
GGML_OP_SOLVE_TRI,
GGML_OP_DELTA_NET,
GGML_OP_UNARY,
@ -2460,6 +2461,15 @@ extern "C" {
bool lower,
bool uni);
GGML_API struct ggml_tensor * ggml_delta_net(
struct ggml_context * ctx,
struct ggml_tensor * q, // [S_k, n_tokens, H_k, n_seqs] - Query (pre-permuted)
struct ggml_tensor * k, // [S_k, n_tokens, H_k, n_seqs] - Key (pre-permuted)
struct ggml_tensor * v, // [S_v, n_tokens, H_v, n_seqs] - Value (pre-permuted)
struct ggml_tensor * g, // [n_tokens, 1, H_k, n_seqs] - Gate logits (pre-permuted)
struct ggml_tensor * beta, // [1, n_tokens, H_k, n_seqs] - Beta (pre-permuted)
struct ggml_tensor * state); // [S_v, S_v*H_v, 1, n_seqs] - Recurrent state
// custom operators
typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);

View File

@ -2014,6 +2014,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_rwkv_wkv7(params, tensor);
} break;
case GGML_OP_DELTA_NET:
{
ggml_compute_forward_delta_net(params, tensor);
} break;
case GGML_OP_SOLVE_TRI:
{
ggml_compute_forward_solve_tri(params, tensor);
@ -2339,6 +2343,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_RWKV_WKV6:
case GGML_OP_GATED_LINEAR_ATTN:
case GGML_OP_RWKV_WKV7:
case GGML_OP_DELTA_NET:
{
n_tasks = n_threads;
} break;

View File

@ -10091,6 +10091,139 @@ void ggml_compute_forward_rwkv_wkv7(
}
}
// ggml_compute_forward_delta_net
static void ggml_compute_forward_delta_net_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];
const ggml_tensor * src3 = dst->src[3];
const ggml_tensor * src4 = dst->src[4];
const ggml_tensor * src5 = dst->src[5];
const int64_t head_dim = src0->ne[0];
const int64_t n_tokens = src0->ne[1];
const int64_t n_heads = src0->ne[2];
const int64_t n_seqs = src0->ne[3];
const int64_t output_size = head_dim * n_tokens * n_heads * n_seqs;
const float * q_data = (const float *) src0->data;
const float * k_data = (const float *) src1->data;
const float * v_data = (const float *) src2->data;
const float * g_data = (const float *) src3->data;
const float * beta_data = (const float *) src4->data;
const float * state_in = (const float *) src5->data;
float * out_data = (float *) dst->data;
float * state_out = out_data + output_size;
const int ith = params->ith;
const int nth = params->nth;
const int64_t total_heads = n_heads * n_seqs;
const int64_t heads_per_thread = (total_heads + nth - 1) / nth;
const int64_t h_start = ith * heads_per_thread;
const int64_t h_end = (h_start + heads_per_thread < total_heads) ? h_start + heads_per_thread : total_heads;
const float eps = 1e-12f;
const float scale = 1.0f / sqrtf((float)head_dim);
float * v_new_buf = (float *)malloc(head_dim * sizeof(float));
if (!v_new_buf) {
return;
}
for (int64_t h_idx = h_start; h_idx < h_end; h_idx++) {
const int64_t batch_idx = h_idx / n_heads;
const int64_t head_idx = h_idx % n_heads;
const int64_t qkv_head_offset = batch_idx * (head_dim * n_tokens * n_heads) + head_idx * (head_dim * n_tokens);
const int64_t qkv_token_stride = head_dim;
const int64_t g_head_offset = batch_idx * (n_tokens * n_heads) + head_idx * n_tokens;
const int64_t state_head_offset = batch_idx * (head_dim * head_dim * n_heads) + head_idx * (head_dim * head_dim);
const int64_t out_head_offset = batch_idx * (head_dim * n_heads * n_tokens) + head_idx * head_dim;
const int64_t out_token_stride = head_dim * n_heads;
for (int64_t i = 0; i < head_dim * head_dim; i++) {
state_out[state_head_offset + i] = state_in[state_head_offset + i];
}
float * state = state_out + state_head_offset;
for (int64_t t = 0; t < n_tokens; t++) {
const float * q_t = q_data + qkv_head_offset + t * qkv_token_stride;
const float * k_t = k_data + qkv_head_offset + t * qkv_token_stride;
const float * v_t = v_data + qkv_head_offset + t * qkv_token_stride;
float g_val = g_data[g_head_offset + t];
float beta_raw = beta_data[g_head_offset + t];
float q_norm_sq = 0.0f, k_norm_sq = 0.0f;
for (int64_t i = 0; i < head_dim; i++) {
q_norm_sq += q_t[i] * q_t[i];
k_norm_sq += k_t[i] * k_t[i];
}
float q_norm_inv = 1.0f / sqrtf(q_norm_sq + eps);
float k_norm_inv = 1.0f / sqrtf(k_norm_sq + eps);
float beta_val = 1.0f / (1.0f + expf(-beta_raw));
float decay = expf(fminf(g_val, 50.0f));
float attn_score = 0.0f;
for (int64_t i = 0; i < head_dim; i++) {
attn_score += (k_t[i] * k_norm_inv) * (q_t[i] * q_norm_inv * scale);
}
float * out_t = out_data + out_head_offset + t * out_token_stride;
for (int64_t row = 0; row < head_dim; row++) {
float v_prime = 0.0f;
float out_val = 0.0f;
for (int64_t col = 0; col < head_dim; col++) {
float k_col = k_t[col] * k_norm_inv;
float q_col = q_t[col] * q_norm_inv * scale;
float s = state[row + col * head_dim];
v_prime += s * k_col * beta_val * decay;
out_val += s * q_col * decay;
}
float v_new = v_t[row] * beta_val - v_prime;
v_new_buf[row] = v_new;
out_t[row] = out_val + v_new * attn_score;
}
for (int64_t col = 0; col < head_dim; col++) {
float k_col = k_t[col] * k_norm_inv;
for (int64_t row = 0; row < head_dim; row++) {
float s = state[row + col * head_dim];
s = decay * s + v_new_buf[row] * k_col;
state[row + col * head_dim] = fminf(fmaxf(s, -1e6f), 1e6f);
}
}
}
}
free(v_new_buf);
}
void ggml_compute_forward_delta_net(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
ggml_compute_forward_delta_net_f32(params, dst);
break;
default:
GGML_ABORT("fatal error");
}
}
// ggml_compute_forward_map_custom1
void ggml_compute_forward_map_custom1(

View File

@ -102,6 +102,7 @@ void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, s
void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_delta_net(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_map_custom3(const struct ggml_compute_params * params, struct ggml_tensor * dst);

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_op_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -55,6 +55,7 @@
#include "ggml-cuda/set-rows.cuh"
#include "ggml-cuda/pad_reflect_1d.cuh"
#include "ggml-cuda/solve_tri.cuh"
#include "ggml-cuda/delta-net.cuh"
#include "ggml-cuda/tri.cuh"
#include "ggml-cuda/cumsum.cuh"
#include "ggml-cuda/fill.cuh"
@ -2735,6 +2736,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_SOLVE_TRI:
ggml_cuda_op_solve_tri(ctx, dst);
break;
case GGML_OP_DELTA_NET:
ggml_cuda_op_delta_net(ctx, dst);
break;
case GGML_OP_FILL:
ggml_cuda_op_fill(ctx, dst);
break;
@ -2904,6 +2908,13 @@ static bool check_node_graph_compatibility(ggml_cgraph * cgraph,
#endif
}
if (node->op == GGML_OP_DELTA_NET) {
use_cuda_graph = false;
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to DELTA_NET recurrent state\n", __func__);
#endif
}
if (!use_cuda_graph) {
break;
}
@ -4632,6 +4643,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_DIAG:
case GGML_OP_SOLVE_TRI:
return true;
case GGML_OP_DELTA_NET:
return op->src[0]->ne[0] <= 256 && op->src[2]->ne[0] <= 256;
default:
return false;

View File

@ -1,86 +1,533 @@
#include "common.cuh"
#include "ggml.h"
#include "solve_tri.cuh"
#include "ggml-cuda.h"
#include <cublas_v2.h>
#define MAX_N_FAST 64
#define MAX_K_FAST 32
#define MAX_K_FAST 64
static __global__ void get_batch_pointers(const float * A,
// Kernel to set up pointer arrays for batched cuBLAS TRSM
// This avoids host-device copy during CUDA graph capture
static __global__ void setup_trsm_batch_pointers(
const float * A,
float * X,
const float ** A_ptrs,
float ** X_ptrs,
int64_t ne02,
int64_t total_batches,
size_t s02,
size_t s03,
size_t s2,
size_t s3) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= total_batches) {
return;
}
const int64_t ne02,
const int64_t total_batches,
const size_t nb02, // stride for A dim 2 (in floats)
const size_t nb03, // stride for A dim 3 (in floats)
const size_t nb2, // stride for X dim 2 (in floats)
const size_t nb3 // stride for X dim 3 (in floats)
) {
const int64_t batch_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (batch_idx >= total_batches) return;
const int64_t i3 = idx / ne02;
const int64_t i2 = idx % ne02;
// Decompose batch_idx into i02, i03
const int64_t i02 = batch_idx % ne02;
const int64_t i03 = batch_idx / ne02;
A_ptrs[idx] = A + i3 * s03 + i2 * s02;
X_ptrs[idx] = X + i3 * s3 + i2 * s2;
A_ptrs[batch_idx] = A + i02 * nb02 + i03 * nb03;
X_ptrs[batch_idx] = X + i02 * nb2 + i03 * nb3;
}
static void solve_tri_f32_cublas(ggml_backend_cuda_context & ctx,
const float * A,
const float * B,
float * X,
int n,
int k,
int64_t ne02,
int64_t ne03,
size_t s02,
size_t s03,
size_t s12,
size_t s13,
size_t s2,
size_t s3,
cudaStream_t stream) {
const float alpha = 1.0f;
const int64_t total_batches = ne02 * ne03;
if (total_batches == 0) {
return;
// Latency-optimized kernel for n=64, k=64 (single-token generation)
static __global__ void solve_tri_f32_64x64_latency(
const float * __restrict__ A,
const float * __restrict__ B,
float * __restrict__ X,
const uint3 ne02,
const size_t nb02,
const size_t nb03,
const size_t nb12,
const size_t nb13,
const size_t nb2,
const size_t nb3)
{
const int batch_idx = blockIdx.x;
const int lane = threadIdx.x;
const int warp_id = threadIdx.y;
const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02);
const int64_t i02 = i02_i03.y;
const int64_t i03 = i02_i03.x;
const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);
const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);
float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
// Shared memory: A is 64x64, X is 64x65 (padded for bank conflicts)
__shared__ float sA[64 * 64];
__shared__ float sX[64 * 65];
__shared__ float sDiagInv[64]; // Precomputed 1/diagonal
const int tid = lane + warp_id * WARP_SIZE;
// Cooperative load of A matrix (4096 elements / 512 threads = 8 per thread)
#pragma unroll 8
for (int i = tid; i < 64 * 64; i += 512) {
sA[i] = A_batch[i];
}
// Bulk copy B -> X (contiguous tensors)
if (X != B) {
const int64_t total_elements_BX = n * k * total_batches;
CUDA_CHECK(cudaMemcpyAsync(X, B, total_elements_BX * sizeof(float), cudaMemcpyDeviceToDevice, stream));
// Cooperative load of B matrix into sX with padding
#pragma unroll 8
for (int i = tid; i < 64 * 64; i += 512) {
const int row = i / 64;
const int col = i % 64;
sX[row * 65 + col] = B_batch[i];
}
const int id = ggml_cuda_get_device();
__syncthreads();
ggml_cuda_pool_alloc<const float *> A_ptrs_alloc(ctx.pool(id), total_batches);
ggml_cuda_pool_alloc<float *> X_ptrs_alloc(ctx.pool(id), total_batches);
// Precompute diagonal inverses (first 2 warps handle this)
if (warp_id == 0) {
if (lane < 32) {
sDiagInv[lane] = 1.0f / sA[lane * 64 + lane];
}
}
if (warp_id == 1) {
if (lane < 32) {
sDiagInv[32 + lane] = 1.0f / sA[(32 + lane) * 64 + (32 + lane)];
}
}
const float ** A_ptrs_dev = A_ptrs_alloc.get();
float ** X_ptrs_dev = X_ptrs_alloc.get();
__syncthreads();
get_batch_pointers<<<(total_batches + 255) / 256, 256, 0, stream>>>(A, X, A_ptrs_dev, X_ptrs_dev, ne02,
total_batches, s02, s03, s2, s3);
// Each warp handles 4 columns: cols = warp_id*4 to warp_id*4+3
const int col_base = warp_id * 4;
CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
#pragma unroll 1
for (int row = 0; row < 64; ++row) {
float sum0 = 0.0f, sum1 = 0.0f, sum2 = 0.0f, sum3 = 0.0f;
// Yes, this is necessary, without this we get RMSE errors
CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_DEFAULT_MATH));
CUBLAS_CHECK(cublasStrsmBatched(ctx.cublas_handle(id), CUBLAS_SIDE_RIGHT, CUBLAS_FILL_MODE_UPPER, CUBLAS_OP_N,
CUBLAS_DIAG_NON_UNIT, k, n, &alpha, A_ptrs_dev, n, X_ptrs_dev, k, total_batches));
if (row > 0) {
for (int j = lane; j < row; j += WARP_SIZE) {
const float a_val = sA[row * 64 + j];
sum0 += a_val * sX[j * 65 + col_base + 0];
sum1 += a_val * sX[j * 65 + col_base + 1];
sum2 += a_val * sX[j * 65 + col_base + 2];
sum3 += a_val * sX[j * 65 + col_base + 3];
}
}
// revert to standard mode from common.cuh
CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_TF32_TENSOR_OP_MATH));
sum0 = warp_reduce_sum(sum0);
sum1 = warp_reduce_sum(sum1);
sum2 = warp_reduce_sum(sum2);
sum3 = warp_reduce_sum(sum3);
GGML_UNUSED_VARS(s12, s13);
if (lane == 0) {
const float inv_diag = sDiagInv[row];
sX[row * 65 + col_base + 0] = (sX[row * 65 + col_base + 0] - sum0) * inv_diag;
sX[row * 65 + col_base + 1] = (sX[row * 65 + col_base + 1] - sum1) * inv_diag;
sX[row * 65 + col_base + 2] = (sX[row * 65 + col_base + 2] - sum2) * inv_diag;
sX[row * 65 + col_base + 3] = (sX[row * 65 + col_base + 3] - sum3) * inv_diag;
}
__syncthreads();
}
// Cooperative write results back
#pragma unroll 8
for (int i = tid; i < 64 * 64; i += 512) {
const int row = i / 64;
const int col = i % 64;
X_batch[i] = sX[row * 65 + col];
}
}
static __global__ void solve_tri_f32_64x64_opt(const float * __restrict__ A,
const float * __restrict__ B,
float * __restrict__ X,
const uint3 ne02,
const size_t nb02,
const size_t nb03,
const size_t nb12,
const size_t nb13,
const size_t nb2,
const size_t nb3) {
const int batch_idx = blockIdx.x;
const int lane = threadIdx.x;
const int warp_id = threadIdx.y;
const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02);
const int64_t i02 = i02_i03.y;
const int64_t i03 = i02_i03.x;
const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);
const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);
float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
// Shared memory: A is 64x64, sXt is 64x65 (padded)
__shared__ float sA[64 * 64];
__shared__ float sXt[64 * 65];
const int tid = lane + warp_id * WARP_SIZE;
// Cooperative load of A matrix (4096 elements / 1024 threads = 4 per thread)
#pragma unroll 4
for (int i = tid; i < 64 * 64; i += 1024) {
sA[i] = A_batch[i];
}
// Cooperative load of B matrix transposed into sXt
// sXt[col * 65 + row] = B[row * 64 + col]
#pragma unroll 4
for (int i = tid; i < 64 * 64; i += 1024) {
const int row = i / 64;
const int col = i % 64;
sXt[col * 65 + row] = B_batch[row * 64 + col];
}
__syncthreads();
// Each warp handles 2 columns: col0 = warp_id*2, col1 = warp_id*2 + 1
const int col0 = warp_id * 2;
const int col1 = warp_id * 2 + 1;
// Forward substitution with all columns processed in parallel
// Each row depends on previous rows, but different columns are independent
#pragma unroll 1
for (int row = 0; row < 64; ++row) {
// Each lane computes partial sum for indices it handles
float sum0 = 0.0f;
float sum1 = 0.0f;
// Sum over j < row
// For row <= 32: each lane handles at most 1 element
// For row > 32: each lane handles at most 2 elements
if (lane < row) {
const float a_val = sA[row * 64 + lane];
sum0 = a_val * sXt[col0 * 65 + lane];
sum1 = a_val * sXt[col1 * 65 + lane];
}
if (row > WARP_SIZE) {
const int j2 = lane + WARP_SIZE;
if (j2 < row) {
const float a_val2 = sA[row * 64 + j2];
sum0 += a_val2 * sXt[col0 * 65 + j2];
sum1 += a_val2 * sXt[col1 * 65 + j2];
}
}
// Warp-level reduction
sum0 = warp_reduce_sum(sum0);
sum1 = warp_reduce_sum(sum1);
// Lane 0 computes and stores the result
if (lane == 0) {
const float a_diag = sA[row * 64 + row];
const float inv_diag = 1.0f / a_diag;
sXt[col0 * 65 + row] = (sXt[col0 * 65 + row] - sum0) * inv_diag;
sXt[col1 * 65 + row] = (sXt[col1 * 65 + row] - sum1) * inv_diag;
}
// Sync within warp to ensure writes are visible before next row reads
__syncwarp();
}
__syncthreads();
// Cooperative write of results back (transpose sXt to X)
#pragma unroll 4
for (int i = tid; i < 64 * 64; i += 1024) {
const int row = i / 64;
const int col = i % 64;
X_batch[row * 64 + col] = sXt[col * 65 + row];
}
}
static __global__ void solve_tri_f32_128x128_opt(const float * __restrict__ A,
const float * __restrict__ B,
float * __restrict__ X,
const uint3 ne02,
const size_t nb02,
const size_t nb03,
const size_t nb12,
const size_t nb13,
const size_t nb2,
const size_t nb3,
const int n,
const int k) {
const int batch_idx = blockIdx.x;
const int lane = threadIdx.x;
const int warp_id = threadIdx.y;
const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02);
const int64_t i02 = i02_i03.y;
const int64_t i03 = i02_i03.x;
const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);
const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);
float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
// Shared memory with padding to avoid bank conflicts
// Layout: sA[128][128] + sXt[128][129]
extern __shared__ char smem_raw[];
float * sA = (float *)smem_raw; // 128×128 (zero-initialized for unused parts)
float * sXt = sA + 128 * 128; // 128×129 (padded)
const int tid = lane + warp_id * WARP_SIZE;
// Zero-initialize shared memory first (important for variable n, k)
#pragma unroll 16
for (int i = tid; i < 128 * 128; i += 1024) {
sA[i] = 0.0f;
}
#pragma unroll 16
for (int i = tid; i < 128 * 129; i += 1024) {
sXt[i] = 0.0f;
}
__syncthreads();
// Cooperative load of A matrix (n×n elements)
for (int i = tid; i < n * n; i += 1024) {
const int row = i / n;
const int col = i % n;
sA[row * 128 + col] = A_batch[row * n + col];
}
// Cooperative load of B matrix transposed into sXt
// sXt[col * 129 + row] = B[row * k + col]
for (int i = tid; i < n * k; i += 1024) {
const int row = i / k;
const int col = i % k;
sXt[col * 129 + row] = B_batch[row * k + col];
}
__syncthreads();
// Each warp handles columns: col_base to col_base+3
// But only process if col < k
const int col_base = warp_id * 4;
// Forward substitution with all columns processed in parallel
for (int row = 0; row < n; ++row) {
float sum0 = 0.0f, sum1 = 0.0f, sum2 = 0.0f, sum3 = 0.0f;
// Sum over j < row - each lane handles multiple elements
for (int j = lane; j < row; j += WARP_SIZE) {
const float a_val = sA[row * 128 + j];
if (col_base + 0 < k) sum0 += a_val * sXt[(col_base + 0) * 129 + j];
if (col_base + 1 < k) sum1 += a_val * sXt[(col_base + 1) * 129 + j];
if (col_base + 2 < k) sum2 += a_val * sXt[(col_base + 2) * 129 + j];
if (col_base + 3 < k) sum3 += a_val * sXt[(col_base + 3) * 129 + j];
}
// Warp-level reduction
sum0 = warp_reduce_sum(sum0);
sum1 = warp_reduce_sum(sum1);
sum2 = warp_reduce_sum(sum2);
sum3 = warp_reduce_sum(sum3);
// Lane 0 computes and stores the result
if (lane == 0) {
const float inv_diag = 1.0f / sA[row * 128 + row];
if (col_base + 0 < k) {
sXt[(col_base + 0) * 129 + row] = (sXt[(col_base + 0) * 129 + row] - sum0) * inv_diag;
}
if (col_base + 1 < k) {
sXt[(col_base + 1) * 129 + row] = (sXt[(col_base + 1) * 129 + row] - sum1) * inv_diag;
}
if (col_base + 2 < k) {
sXt[(col_base + 2) * 129 + row] = (sXt[(col_base + 2) * 129 + row] - sum2) * inv_diag;
}
if (col_base + 3 < k) {
sXt[(col_base + 3) * 129 + row] = (sXt[(col_base + 3) * 129 + row] - sum3) * inv_diag;
}
}
__syncwarp();
}
__syncthreads();
// Cooperative write of results back (transpose sXt to X)
for (int i = tid; i < n * k; i += 1024) {
const int row = i / k;
const int col = i % k;
X_batch[row * k + col] = sXt[col * 129 + row];
}
}
static __global__ void solve_tri_f32_256x256_tiled(const float * __restrict__ A,
const float * __restrict__ B,
float * __restrict__ X,
const uint3 ne02,
const size_t nb02,
const size_t nb03,
const size_t nb12,
const size_t nb13,
const size_t nb2,
const size_t nb3,
const int n,
const int k) {
const int batch_idx = blockIdx.x;
const int lane = threadIdx.x;
const int warp_id = threadIdx.y;
const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02);
const int64_t i02 = i02_i03.y;
const int64_t i03 = i02_i03.x;
const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);
const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);
float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
// Tiled approach using 64×64 tiles to fit in shared memory
constexpr int TILE_SIZE = 64;
extern __shared__ char smem_raw[];
float * sA_tile = (float *)smem_raw; // 64×64 = 16KB
float * sXt_tile = sA_tile + TILE_SIZE * TILE_SIZE; // 64×65 = 16.25KB (padded)
float * sA_off = sXt_tile + TILE_SIZE * (TILE_SIZE+1); // 64×64 = 16KB (for off-diagonal blocks)
const int tid = lane + warp_id * WARP_SIZE;
// Initialize X = B (we'll solve in-place conceptually, using global memory)
for (int i = tid; i < n * k; i += 1024) {
X_batch[i] = B_batch[i];
}
__syncthreads();
// Process tile-by-tile along the diagonal
for (int tile_row = 0; tile_row < n; tile_row += TILE_SIZE) {
const int tile_n = min(TILE_SIZE, n - tile_row); // Actual rows in this tile
// Zero-init and load diagonal tile of A
for (int i = tid; i < TILE_SIZE * TILE_SIZE; i += 1024) {
sA_tile[i] = 0.0f;
}
__syncthreads();
for (int i = tid; i < tile_n * tile_n; i += 1024) {
int local_row = i / tile_n;
int local_col = i % tile_n;
sA_tile[local_row * TILE_SIZE + local_col] = A_batch[(tile_row + local_row) * n + tile_row + local_col];
}
__syncthreads();
// For each column tile of X
for (int tile_col = 0; tile_col < k; tile_col += TILE_SIZE) {
const int tile_k = min(TILE_SIZE, k - tile_col); // Actual columns in this tile
// Zero-init and load X tile transposed
for (int i = tid; i < TILE_SIZE * (TILE_SIZE+1); i += 1024) {
sXt_tile[i] = 0.0f;
}
__syncthreads();
for (int i = tid; i < tile_n * tile_k; i += 1024) {
int local_row = i / tile_k;
int local_col = i % tile_k;
sXt_tile[local_col * (TILE_SIZE+1) + local_row] =
X_batch[(tile_row + local_row) * k + tile_col + local_col];
}
__syncthreads();
// Apply updates from previous tile rows
for (int prev_tile = 0; prev_tile < tile_row; prev_tile += TILE_SIZE) {
const int prev_n = min(TILE_SIZE, n - prev_tile);
// Zero-init and load off-diagonal block
for (int i = tid; i < TILE_SIZE * TILE_SIZE; i += 1024) {
sA_off[i] = 0.0f;
}
__syncthreads();
for (int i = tid; i < tile_n * prev_n; i += 1024) {
int local_row = i / prev_n;
int local_col = i % prev_n;
sA_off[local_row * TILE_SIZE + local_col] = A_batch[(tile_row + local_row) * n + prev_tile + local_col];
}
__syncthreads();
// Update: X_tile -= A_off @ X_prev
int col0 = warp_id * 2;
int col1 = warp_id * 2 + 1;
for (int row = 0; row < tile_n; row++) {
float sum0 = 0.0f, sum1 = 0.0f;
for (int j = lane; j < prev_n; j += WARP_SIZE) {
float a_val = sA_off[row * TILE_SIZE + j];
if (col0 < tile_k) {
float x_prev0 = X_batch[(prev_tile + j) * k + tile_col + col0];
sum0 += a_val * x_prev0;
}
if (col1 < tile_k) {
float x_prev1 = X_batch[(prev_tile + j) * k + tile_col + col1];
sum1 += a_val * x_prev1;
}
}
sum0 = warp_reduce_sum(sum0);
sum1 = warp_reduce_sum(sum1);
if (lane == 0) {
if (col0 < tile_k) {
sXt_tile[col0 * (TILE_SIZE+1) + row] -= sum0;
}
if (col1 < tile_k) {
sXt_tile[col1 * (TILE_SIZE+1) + row] -= sum1;
}
}
__syncwarp();
}
__syncthreads();
}
// Solve the diagonal tile
int col0 = warp_id * 2;
int col1 = warp_id * 2 + 1;
for (int row = 0; row < tile_n; ++row) {
float sum0 = 0.0f, sum1 = 0.0f;
if (lane < row) {
float a_val = sA_tile[row * TILE_SIZE + lane];
if (col0 < tile_k) sum0 = a_val * sXt_tile[col0 * (TILE_SIZE+1) + lane];
if (col1 < tile_k) sum1 = a_val * sXt_tile[col1 * (TILE_SIZE+1) + lane];
}
if (row > WARP_SIZE) {
int j2 = lane + WARP_SIZE;
if (j2 < row) {
float a_val2 = sA_tile[row * TILE_SIZE + j2];
if (col0 < tile_k) sum0 += a_val2 * sXt_tile[col0 * (TILE_SIZE+1) + j2];
if (col1 < tile_k) sum1 += a_val2 * sXt_tile[col1 * (TILE_SIZE+1) + j2];
}
}
sum0 = warp_reduce_sum(sum0);
sum1 = warp_reduce_sum(sum1);
if (lane == 0) {
float inv_diag = 1.0f / sA_tile[row * TILE_SIZE + row];
if (col0 < tile_k) {
sXt_tile[col0 * (TILE_SIZE+1) + row] =
(sXt_tile[col0 * (TILE_SIZE+1) + row] - sum0) * inv_diag;
}
if (col1 < tile_k) {
sXt_tile[col1 * (TILE_SIZE+1) + row] =
(sXt_tile[col1 * (TILE_SIZE+1) + row] - sum1) * inv_diag;
}
}
__syncwarp();
}
__syncthreads();
// Write solved tile back to global memory
for (int i = tid; i < tile_n * tile_k; i += 1024) {
int local_row = i / tile_k;
int local_col = i % tile_k;
X_batch[(tile_row + local_row) * k + tile_col + local_col] =
sXt_tile[local_col * (TILE_SIZE+1) + local_row];
}
__syncthreads();
}
}
}
// ======================
// Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction
// ======================
// When ncols_template == 0 the bounds for the loops in this function are not
// known and can't be unrolled. As we want to keep pragma unroll for all other
// cases we supress the clang transformation warning here.
@ -88,7 +535,9 @@ static void solve_tri_f32_cublas(ggml_backend_cuda_context & ctx,
# pragma clang diagnostic push
# pragma clang diagnostic ignored "-Wpass-failed"
#endif // __clang__
template <int n_template, int k_template>
// Template parameters: n_template/k_template are the matrix dimensions when known at compile time (0 = runtime)
// threads_y_template is the number of threads in y dimension (max 32 to stay within 1024 thread limit)
template <int n_template, int k_template, int threads_y_template>
static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
const float * __restrict__ B,
float * __restrict__ X,
@ -103,14 +552,10 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
const int k_arg) {
const int n = n_template == 0 ? n_arg : n_template;
const int k = k_template == 0 ? k_arg : k_template;
const int threads_y = threads_y_template == 0 ? blockDim.y : threads_y_template;
const int batch_idx = blockIdx.x;
const int lane = threadIdx.x;
const int col_idx = threadIdx.y;
if (col_idx >= k) {
return;
}
const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02);
const int64_t i02 = i02_i03.y;
@ -121,58 +566,94 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
__shared__ float sA[MAX_N_FAST * MAX_N_FAST];
__shared__ float sXt[MAX_N_FAST * (MAX_K_FAST + 1)];
const int offset = threadIdx.x + threadIdx.y * blockDim.x;
const int block_threads = blockDim.x * blockDim.y;
// Load A matrix into shared memory
#pragma unroll
for (int i = 0; i < n * n; i += k * WARP_SIZE) {
const int i0 = i + offset;
for (int i = 0; i < n * n; i += block_threads) {
int i0 = i + offset;
if (i0 < n * n) {
sA[i0] = A_batch[i0];
}
}
const int rows_per_warp = (n + WARP_SIZE - 1) / WARP_SIZE;
const int cols_per_thread = (k + threads_y - 1) / threads_y;
// Load B matrix into shared memory (transposed as sXt)
// Each thread handles multiple columns when k > threads_y
for (int c = 0; c < cols_per_thread; c++) {
const int col_idx = threadIdx.y + c * threads_y;
if (col_idx < k) {
#pragma unroll
for (int i = 0; i < rows_per_warp; i++) {
const int i0 = lane + i * WARP_SIZE;
if (i0 < n) {
sXt[col_idx * n + i0] = B_batch[i0 * k + col_idx];
}
}
}
}
__syncthreads();
float x_low = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f;
float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f;
const int half = WARP_SIZE;
const int nrows_low = (n < half) ? n : half;
// Solve for each column this thread handles
for (int c = 0; c < cols_per_thread; c++) {
const int col_idx = threadIdx.y + c * threads_y;
if (col_idx >= k) {
continue;
}
#pragma unroll
for (int row = 0; row < nrows_low; ++row) {
for (int row = 0; row < n; ++row) {
float sum = 0.0f;
if (lane < row) {
sum += sA[row * n + lane] * x_low;
}
sum = warp_reduce_sum(sum);
if (lane == row) {
x_low = (x_low - sum) / sA[row * n + row];
}
}
#pragma unroll
for (int row = half; row < n; ++row) {
float sum = sA[row * n + lane] * x_low;
const int j = half + lane;
{
int j = lane;
if (j < row) {
sum += sA[row * n + j] * x_high;
sum += sA[row * n + j] * sXt[col_idx * n + j];
}
}
if (row >= WARP_SIZE) {
int j = WARP_SIZE + lane;
if (j < row) {
sum += sA[row * n + j] * sXt[col_idx * n + j];
}
}
sum = warp_reduce_sum(sum);
if (lane == row - half) {
x_high = (x_high - sum) / sA[row * n + row];
if (lane == 0) {
const float b_val = sXt[col_idx * n + row];
const float a_diag = sA[row * n + row];
// no safeguards for division by zero because that indicates corrupt
// data anyway
sXt[col_idx * n + row] = (b_val - sum) / a_diag;
}
}
// Sync between columns to ensure writes are visible
if (c + 1 < cols_per_thread) {
__syncwarp();
}
}
__syncthreads();
// Write results back
for (int c = 0; c < cols_per_thread; c++) {
const int col_idx = threadIdx.y + c * threads_y;
if (col_idx < k) {
#pragma unroll
for (int rr = 0; rr < 2; ++rr) {
const int row = rr * WARP_SIZE + lane;
if (row < n) {
const float val = (row < half) ? x_low : x_high;
X_batch[row * k + col_idx] = val;
for (int i = 0; i < rows_per_warp; i++) {
const int i0 = lane + i * WARP_SIZE;
if (i0 < n) {
X_batch[i0 * k + col_idx] = sXt[col_idx * n + i0];
}
}
}
}
}
@ -180,6 +661,76 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
# pragma clang diagnostic pop
#endif // __clang__
// cuBLAS batched TRSM fallback for larger matrices or as robust path
// Solves A * X = B where A is lower triangular
// This function modifies X in-place (X should be initialized with B)
static void solve_tri_f32_cublas(
ggml_backend_cuda_context & ctx,
const float * A,
float * X, // Input: B, Output: solution X (in-place)
int n,
int k,
int64_t ne02,
int64_t ne03,
size_t nb02,
size_t nb03,
size_t nb2,
size_t nb3,
cudaStream_t stream
) {
const int64_t total_batches = ne02 * ne03;
// Allocate pointer arrays on device
ggml_cuda_pool_alloc<const float *> A_ptrs(ctx.pool(), total_batches);
ggml_cuda_pool_alloc<float *> X_ptrs(ctx.pool(), total_batches);
// Set up pointer arrays on device (CUDA graph compatible)
{
const int block_size = 256;
const int grid_size = (total_batches + block_size - 1) / block_size;
setup_trsm_batch_pointers<<<grid_size, block_size, 0, stream>>>(
A, X,
A_ptrs.get(), X_ptrs.get(),
ne02, total_batches,
nb02, nb03, nb2, nb3
);
CUDA_CHECK(cudaGetLastError());
}
// Get cuBLAS handle and set stream
cublasHandle_t handle = ctx.cublas_handle();
cublasSetStream(handle, stream);
// Save current math mode and set to default for accuracy
// (TF32 can cause numerical issues with triangular solves)
cublasMath_t prev_math_mode;
cublasGetMathMode(handle, &prev_math_mode);
cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH);
const float alpha = 1.0f;
cublasStatus_t status = cublasStrsmBatched(
handle,
CUBLAS_SIDE_RIGHT, // A is on the right: X * A = B
CUBLAS_FILL_MODE_UPPER, // A^T is upper (since A is lower in row-major)
CUBLAS_OP_N, // No additional transpose
CUBLAS_DIAG_NON_UNIT, // Diagonal is not assumed to be 1
k, // m: rows of X^T (columns of X)
n, // n: columns of X^T (rows of X) = size of A
&alpha,
(const float **)A_ptrs.get(), n, // lda = n (leading dimension)
(float **)X_ptrs.get(), k, // ldb = k (leading dimension of X^T)
total_batches
);
// Restore previous math mode
cublasSetMathMode(handle, prev_math_mode);
if (status != CUBLAS_STATUS_SUCCESS) {
GGML_LOG_ERROR("cuBLAS batched TRSM failed: %d\n", (int)status);
}
}
static void solve_tri_f32_cuda(const float * A,
const float * B,
float * X,
@ -195,81 +746,133 @@ static void solve_tri_f32_cuda(const float * A,
size_t nb3,
cudaStream_t stream) {
const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
dim3 threads(WARP_SIZE, k);
dim3 grid(ne02 * ne03);
// Handle large matrices first (256×256 and 65-128 range)
// Route sizes 65-256 to the tiled kernel
if (n > 64 || k > 64) {
// Use the tiled kernel which works for any size up to 256
// and only requires ~48KB shared memory (within standard limits)
dim3 threads_256(WARP_SIZE, 32); // 1024 threads
// Shared memory: 64×64 + 64×65 + 64×64 = 16KB + 16.25KB + 16KB = ~48KB
const size_t smem_size = (64 * 64 + 64 * 65 + 64 * 64) * sizeof(float);
// Configure extended shared memory for this kernel
static bool smem_configured_tiled = false;
if (!smem_configured_tiled) {
cudaFuncSetAttribute(solve_tri_f32_256x256_tiled,
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
smem_configured_tiled = true;
}
solve_tri_f32_256x256_tiled<<<grid, threads_256, smem_size, stream>>>(
A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
return;
}
// Limit threads_y to 32 to ensure we don't exceed 1024 threads per block (32 * 32 = 1024)
const int threads_y = k <= 32 ? k : 32;
dim3 threads(WARP_SIZE, threads_y);
if (n == 64) {
switch (k) {
case 64:
{
// Use optimized kernel for n=64, k=64 case (common in Qwen3 Next DeltaNet)
// Block config: 32x32 = 1024 threads (32 warps)
dim3 threads_64x64(WARP_SIZE, 32);
solve_tri_f32_64x64_opt
<<<grid, threads_64x64, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3);
}
break;
case 48:
// k=48 needs 2 columns per thread (threads_y=32, some threads handle 1, some 2)
solve_tri_f32_fast<64, 48, 32>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 40:
// k=40 needs 2 columns per thread (threads_y=32, some threads handle 1, some 2)
solve_tri_f32_fast<64, 40, 32>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 32:
solve_tri_f32_fast<64, 32>
solve_tri_f32_fast<64, 32, 32>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 16:
solve_tri_f32_fast<64, 16>
solve_tri_f32_fast<64, 16, 16>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 14:
solve_tri_f32_fast<64, 14>
solve_tri_f32_fast<64, 14, 14>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 12:
solve_tri_f32_fast<64, 12>
solve_tri_f32_fast<64, 12, 12>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 10:
solve_tri_f32_fast<64, 10>
solve_tri_f32_fast<64, 10, 10>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 8:
solve_tri_f32_fast<64, 8>
solve_tri_f32_fast<64, 8, 8>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 6:
solve_tri_f32_fast<64, 6>
solve_tri_f32_fast<64, 6, 6>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 4:
solve_tri_f32_fast<64, 4>
solve_tri_f32_fast<64, 4, 4>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 2:
solve_tri_f32_fast<64, 2>
solve_tri_f32_fast<64, 2, 2>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 1:
solve_tri_f32_fast<64, 1>
solve_tri_f32_fast<64, 1, 1>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
default:
solve_tri_f32_fast<0, 0>
solve_tri_f32_fast<0, 0, 0>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
}
} else { // run general case
solve_tri_f32_fast<0, 0>
solve_tri_f32_fast<0, 0, 0>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
}
}
void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; // A (n×n, lower triangular)
const ggml_tensor * src1 = dst->src[1]; // B (n×k)
const ggml_tensor * src0 = dst->src[0]; // A (triangular n x n matrix)
const ggml_tensor * src1 = dst->src[1]; // B (right hand side of n x k equation columns)
ggml_is_contiguous(src0);
ggml_is_contiguous(src1);
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
const int64_t n = src0->ne[0];
const int64_t k = src1->ne[0];
const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3];
if (n <= MAX_N_FAST && k <= MAX_K_FAST) {
solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k,
src0->ne[2], src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
dst->nb[3] / sizeof(float), ctx.stream());
} else {
solve_tri_f32_cublas(ctx, (const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k,
ne02, ne03, src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
dst->nb[3] / sizeof(float), ctx.stream());
}
const int64_t total_batches = src0->ne[2] * src0->ne[3];
const size_t X_size = n * k * total_batches * sizeof(float);
// Copy B to X (cuBLAS solves in-place)
CUDA_CHECK(cudaMemcpyAsync(
dst->data, src1->data, X_size,
cudaMemcpyDeviceToDevice, ctx.stream()
));
solve_tri_f32_cublas(
ctx,
(const float *) src0->data,
(float *) dst->data,
n, k,
src0->ne[2], src0->ne[3],
src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
dst->nb[2] / sizeof(float), dst->nb[3] / sizeof(float),
ctx.stream()
);
}

View File

@ -1028,6 +1028,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"GATED_LINEAR_ATTN",
"RWKV_WKV7",
"SOLVE_TRI",
"DELTA_NET",
"UNARY",
@ -1045,7 +1046,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"GLU",
};
static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@ -1137,6 +1138,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"gated_linear_attn(k, v, q, gate, s)",
"rwkv_wkv7(r, w, k, v, a, b, s)",
"A X = B, A triangular, solve X",
"delta_net(q, k, v, g, beta, state)",
"unary(x)",
@ -1154,7 +1156,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"glu(x)",
};
static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@ -6093,6 +6095,63 @@ struct ggml_tensor * ggml_solve_tri(
return result;
}
// delta_net
struct ggml_tensor * ggml_delta_net(
struct ggml_context * ctx,
struct ggml_tensor * q,
struct ggml_tensor * k,
struct ggml_tensor * v,
struct ggml_tensor * g,
struct ggml_tensor * beta,
struct ggml_tensor * state) {
GGML_ASSERT(ggml_is_contiguous(q));
GGML_ASSERT(ggml_is_contiguous(k));
GGML_ASSERT(ggml_is_contiguous(v));
GGML_ASSERT(ggml_is_contiguous(g));
GGML_ASSERT(ggml_is_contiguous(beta));
GGML_ASSERT(ggml_is_contiguous(state));
GGML_ASSERT(q->type == GGML_TYPE_F32);
GGML_ASSERT(k->type == GGML_TYPE_F32);
GGML_ASSERT(v->type == GGML_TYPE_F32);
GGML_ASSERT(g->type == GGML_TYPE_F32);
GGML_ASSERT(beta->type == GGML_TYPE_F32);
GGML_ASSERT(state->type == GGML_TYPE_F32);
const int64_t S_k = q->ne[0];
const int64_t n_tokens = q->ne[1];
const int64_t H_k = q->ne[2];
const int64_t n_seqs = q->ne[3];
const int64_t S_v = v->ne[0];
const int64_t H_v = v->ne[2];
GGML_UNUSED(S_k);
GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == n_tokens && k->ne[2] == H_k && k->ne[3] == n_seqs);
GGML_ASSERT(v->ne[1] == n_tokens && v->ne[3] == n_seqs);
GGML_ASSERT(g->ne[0] == n_tokens && g->ne[2] == H_k && g->ne[3] == n_seqs);
GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[3] == n_seqs);
GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[3] == n_seqs);
GGML_ASSERT(H_k == H_v);
const int64_t output_size = S_v * H_v * n_tokens * n_seqs;
const int64_t state_size = S_v * S_v * H_v * n_seqs;
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, output_size + state_size);
result->op = GGML_OP_DELTA_NET;
result->src[0] = q;
result->src[1] = k;
result->src[2] = v;
result->src[3] = g;
result->src[4] = beta;
result->src[5] = state;
return result;
}
////////////////////////////////////////////////////////////////////////////////
struct ggml_hash_set ggml_hash_set_new(size_t size) {

View File

@ -460,6 +460,15 @@ private:
ggml_tensor * diag_mask,
int il);
ggml_tensor * build_delta_net_fused(
ggml_tensor * q,
ggml_tensor * k,
ggml_tensor * v,
ggml_tensor * g,
ggml_tensor * beta,
ggml_tensor * state,
int il);
ggml_tensor * build_delta_net_autoregressive(
ggml_tensor * q,
ggml_tensor * k,

View File

@ -33,12 +33,9 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
// Determine layer type and build appropriate attention mechanism
if (hparams.is_recurrent(il)) {
// Linear attention layer (gated delta net)
cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, diag_mask, il);
} else {
// Full attention layer
cur = build_layer_attn(inp->get_attn(), cur, inp_pos, il);
}
@ -47,37 +44,28 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
// Residual connection
cur = ggml_add(ctx0, cur, inpSA);
cb(cur, "attn_residual", il);
// Save the tensor before post-attention norm for residual connection
ggml_tensor * ffn_residual = cur;
// Post-attention norm
ggml_tensor * attn_post_norm = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il);
cb(attn_post_norm, "attn_post_norm", il);
// FFN layer (MoE or dense) - without residual connection
cur = build_layer_ffn(attn_post_norm, il);
cb(cur, "ffn_out", il);
// Residual connection for FFN - add to the tensor from before post_attention_layernorm
cur = ggml_add(ctx0, cur, ffn_residual);
cb(cur, "post_moe", il);
// Input for next layer
inpL = cur;
}
cur = inpL;
// Final norm
cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
// LM head
cur = build_lora_mm(model.output, cur);
cb(cur, "result_output", -1);
@ -426,6 +414,78 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_autoregressive(
return ggml_concat(ctx0, flat_output, flat_state, 0);
}
ggml_tensor * llm_build_qwen3next::build_delta_net_fused(
ggml_tensor * q,
ggml_tensor * k,
ggml_tensor * v,
ggml_tensor * g,
ggml_tensor * beta,
ggml_tensor * state,
int il) {
GGML_ASSERT(ggml_is_contiguous(q));
GGML_ASSERT(ggml_is_contiguous(k));
GGML_ASSERT(ggml_is_contiguous(v));
GGML_ASSERT(ggml_is_contiguous(g));
GGML_ASSERT(ggml_is_contiguous(beta));
GGML_ASSERT(ggml_is_contiguous(state));
const int64_t S_k = q->ne[0];
const int64_t H_k = q->ne[1];
const int64_t n_tokens = q->ne[2];
const int64_t n_seqs = q->ne[3];
const int64_t S_v = v->ne[0];
const int64_t H_v = v->ne[1];
GGML_ASSERT(v->ne[2] == n_tokens);
GGML_ASSERT(k->ne[2] == n_tokens);
GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
GGML_ASSERT(H_k == H_v);
q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 1, 3, 0, 2), n_tokens, 1, H_k, n_seqs);
beta = ggml_cont_4d(ctx0, ggml_permute(ctx0, beta, 1, 2, 0, 3), 1, n_tokens, H_k, n_seqs);
cb(q, "q_fused", il);
cb(k, "k_fused", il);
cb(v, "v_fused", il);
cb(g, "g_fused", il);
cb(beta, "beta_fused", il);
ggml_tensor * fused_result = ggml_delta_net(ctx0, q, k, v, g, beta, state);
cb(fused_result, "delta_net_fused_raw", il);
const int64_t output_size = S_v * H_v * n_tokens * n_seqs;
const int64_t state_size = S_v * S_v * H_v * n_seqs;
ggml_tensor * output_4d = ggml_view_4d(ctx0, fused_result,
S_v, H_v, n_tokens, n_seqs,
S_v * ggml_element_size(fused_result),
S_v * H_v * ggml_element_size(fused_result),
S_v * H_v * n_tokens * ggml_element_size(fused_result),
0);
cb(output_4d, "fused_output_4d", il);
ggml_tensor * flat_output = ggml_cont_1d(ctx0, output_4d, output_size);
cb(flat_output, "fused_flat_output", il);
ggml_tensor * flat_state = ggml_view_1d(ctx0, fused_result, state_size,
output_size * ggml_element_size(fused_result));
cb(flat_state, "fused_flat_state", il);
ggml_tensor * result = ggml_concat(ctx0, flat_output, flat_state, 0);
return result;
}
ggml_tensor * llm_build_qwen3next::build_norm_gated(
ggml_tensor * input,
ggml_tensor * weights,
@ -445,16 +505,11 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn(
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
// Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention
// Qwen3Next uses a single Q projection that outputs query + gate
ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur_full, "Qcur_full", il);
Qcur_full = ggml_reshape_4d(ctx0, Qcur_full, n_embd_head * 2, n_head, n_tokens, 1);
// Split Q projection into query and gate
// The split should be along dimension 0 (the feature dimension)
ggml_tensor * Qcur = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1,
Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0);
ggml_tensor * gate =
@ -463,11 +518,9 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn(
cb(Qcur, "Qcur", il);
cb(gate, "gate", il);
// Now reshape Qcur to [n_embd_head, n_head, n_tokens] for multi-head attention
Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
cb(Qcur, "Qcur_reshaped", il);
// Apply Q normalization
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
cb(Qcur, "Qcur_normed", il);
@ -477,18 +530,15 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn(
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
// Apply K normalization
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il);
cb(Kcur, "Kcur_normed", il);
// Reshape gate to [n_embd, n_tokens] for the sigmoid gating (flatten the heads)
gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
cb(gate, "gate_reshaped", il);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
// Apply RoPE
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@ -503,7 +553,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn(
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
// Attention computation
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
cur = build_attn(inp,
@ -737,13 +786,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
cb(k_conv, "k_conv_predelta", il);
cb(v_conv, "v_conv_predelta", il);
// Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens
ggml_tensor * attn_out;
if (n_seq_tokens == 1) {
attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
} else {
attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
}
ggml_tensor * attn_out = build_delta_net_fused(q_conv, k_conv, v_conv, gate, beta, state, il);
cb(attn_out, "attn_out", il);
// The tensors were concatenated 1d, so we need to extract them 1d as well
@ -795,9 +838,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
}
ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int il) {
// Check if this is an MoE layer
if (model.layers[il].ffn_gate_inp != nullptr) {
// MoE branch
ggml_tensor * moe_out =
build_moe_ffn(cur,
model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps,
@ -807,7 +848,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int
true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
cb(moe_out, "ffn_moe_out", il);
// Add shared experts if present - following Qwen3Next reference implementation
if (model.layers[il].ffn_up_shexp != nullptr) {
ggml_tensor * ffn_shexp =
build_ffn(cur,
@ -818,23 +858,15 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(ffn_shexp, "ffn_shexp", il);
// Apply shared expert gating as in the reference implementation
// The shared expert has its own gate that is sigmoided
// Note: ffn_gate_inp_shexp is the shared expert gate (outputs 1 value per token)
ggml_tensor * shared_gate = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur);
cb(shared_gate, "shared_expert_gate", il);
// Apply sigmoid to the gate
shared_gate = ggml_sigmoid(ctx0, shared_gate);
cb(shared_gate, "shared_expert_gate_sigmoid", il);
// The gate needs to be broadcast to match the dimensions of ffn_shexp
// ffn_shexp is [n_embd, n_tokens, 1, 1] and shared_gate is [1, n_tokens, 1, 1]
// We need to repeat the gate along the feature dimension
shared_gate = ggml_repeat(ctx0, shared_gate, ffn_shexp);
cb(shared_gate, "shared_expert_gate_broadcast", il);
// Apply the gate to the shared expert output
ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate);
cb(ffn_shexp, "ffn_shexp_gated", il);
@ -844,7 +876,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int
cur = moe_out;
}
} else {
// Dense FFN branch (not currently used I believe)
cur = build_ffn(cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,

View File

@ -3550,6 +3550,34 @@ struct test_rwkv_wkv7 : public test_case {
}
};
// GGML_OP_DELTA_NET
struct test_delta_net : public test_case {
const ggml_type type;
const int64_t n_heads;
const int64_t head_dim;
const int64_t n_tokens;
const int64_t n_seqs;
std::string vars() override {
return VARS_TO_STR5(type, n_heads, head_dim, n_tokens, n_seqs);
}
test_delta_net(ggml_type type = GGML_TYPE_F32,
int64_t n_heads = 8, int64_t head_dim = 64, int64_t n_tokens = 32, int64_t n_seqs = 2)
: type(type), n_heads(n_heads), head_dim(head_dim), n_tokens(n_tokens), n_seqs(n_seqs) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * q = ggml_new_tensor_4d(ctx, type, head_dim, n_tokens, n_heads, n_seqs);
ggml_tensor * k = ggml_new_tensor_4d(ctx, type, head_dim, n_tokens, n_heads, n_seqs);
ggml_tensor * v = ggml_new_tensor_4d(ctx, type, head_dim, n_tokens, n_heads, n_seqs);
ggml_tensor * g = ggml_new_tensor_4d(ctx, type, n_tokens, 1, n_heads, n_seqs);
ggml_tensor * beta = ggml_new_tensor_4d(ctx, type, 1, n_tokens, n_heads, n_seqs);
ggml_tensor * state = ggml_new_tensor_4d(ctx, type, head_dim, head_dim * n_heads, 1, n_seqs);
return ggml_delta_net(ctx, q, k, v, g, beta, state);
}
};
// GGML_OP_MUL_MAT
struct test_mul_mat : public test_case {
const ggml_type type_a;
@ -7322,6 +7350,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4));
test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 128, 4));
test_cases.emplace_back(new test_delta_net(GGML_TYPE_F32, 8, 64, 1, 1));
test_cases.emplace_back(new test_delta_net(GGML_TYPE_F32, 8, 64, 32, 1));
test_cases.emplace_back(new test_delta_net(GGML_TYPE_F32, 8, 64, 32, 2));
test_cases.emplace_back(new test_delta_net(GGML_TYPE_F32, 8, 64, 128, 2));
#if 0
// > 4GB A matrix. Too slow to be enabled by default.
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 900000, 3, 2592, {1, 1}, {1, 1}));