Merge 128a6c2831 into 58062860af
commit 44765b04b3
@@ -551,6 +551,7 @@ extern "C" {
         GGML_OP_GATED_LINEAR_ATTN,
         GGML_OP_RWKV_WKV7,
         GGML_OP_SOLVE_TRI,
+        GGML_OP_DELTA_NET,
 
         GGML_OP_UNARY,
 
@@ -2460,6 +2461,15 @@ extern "C" {
             bool lower,
             bool uni);
 
+    GGML_API struct ggml_tensor * ggml_delta_net(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * q,     // [S_k, n_tokens, H_k, n_seqs] - Query (pre-permuted)
+            struct ggml_tensor  * k,     // [S_k, n_tokens, H_k, n_seqs] - Key (pre-permuted)
+            struct ggml_tensor  * v,     // [S_v, n_tokens, H_v, n_seqs] - Value (pre-permuted)
+            struct ggml_tensor  * g,     // [n_tokens, 1, H_k, n_seqs] - Gate logits (pre-permuted)
+            struct ggml_tensor  * beta,  // [1, n_tokens, H_k, n_seqs] - Beta (pre-permuted)
+            struct ggml_tensor  * state); // [S_v, S_v*H_v, 1, n_seqs] - Recurrent state
+
     // custom operators
 
     typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
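Editor's note (not part of the patch): a minimal sketch of how the new entry point might be wired up, following the shape comments above. The sizes S, H, n_tokens and n_seqs are placeholder assumptions, and the flat output/state packing of the result is the one described by the ggml_delta_net() constructor later in this diff.

    // hypothetical caller-side sketch, not taken from the patch
    struct ggml_init_params ip = { /*.mem_size =*/ 256u*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
    struct ggml_context * ctx = ggml_init(ip);

    const int64_t S = 64, H = 8, n_tokens = 16, n_seqs = 1;   // placeholder sizes
    struct ggml_tensor * q     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_tokens, H, n_seqs);
    struct ggml_tensor * k     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_tokens, H, n_seqs);
    struct ggml_tensor * v     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_tokens, H, n_seqs);
    struct ggml_tensor * g     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_tokens, 1, H, n_seqs);
    struct ggml_tensor * beta  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, n_tokens, H, n_seqs);
    struct ggml_tensor * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, S*H, 1, n_seqs);

    // the result is a flat F32 tensor: per-token output followed by the updated state
    struct ggml_tensor * out = ggml_delta_net(ctx, q, k, v, g, beta, state);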
@@ -2014,6 +2014,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rwkv_wkv7(params, tensor);
             } break;
+        case GGML_OP_DELTA_NET:
+            {
+                ggml_compute_forward_delta_net(params, tensor);
+            } break;
         case GGML_OP_SOLVE_TRI:
             {
                 ggml_compute_forward_solve_tri(params, tensor);
@@ -2339,6 +2343,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_GATED_LINEAR_ATTN:
         case GGML_OP_RWKV_WKV7:
+        case GGML_OP_DELTA_NET:
             {
                 n_tasks = n_threads;
             } break;
@@ -10091,6 +10091,139 @@ void ggml_compute_forward_rwkv_wkv7(
     }
 }
 
+// ggml_compute_forward_delta_net
+
+static void ggml_compute_forward_delta_net_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+    const ggml_tensor * src3 = dst->src[3];
+    const ggml_tensor * src4 = dst->src[4];
+    const ggml_tensor * src5 = dst->src[5];
+
+    const int64_t head_dim = src0->ne[0];
+    const int64_t n_tokens = src0->ne[1];
+    const int64_t n_heads  = src0->ne[2];
+    const int64_t n_seqs   = src0->ne[3];
+
+    const int64_t output_size = head_dim * n_tokens * n_heads * n_seqs;
+
+    const float * q_data    = (const float *) src0->data;
+    const float * k_data    = (const float *) src1->data;
+    const float * v_data    = (const float *) src2->data;
+    const float * g_data    = (const float *) src3->data;
+    const float * beta_data = (const float *) src4->data;
+    const float * state_in  = (const float *) src5->data;
+    float * out_data  = (float *) dst->data;
+    float * state_out = out_data + output_size;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t total_heads = n_heads * n_seqs;
+    const int64_t heads_per_thread = (total_heads + nth - 1) / nth;
+    const int64_t h_start = ith * heads_per_thread;
+    const int64_t h_end = (h_start + heads_per_thread < total_heads) ? h_start + heads_per_thread : total_heads;
+
+    const float eps = 1e-12f;
+    const float scale = 1.0f / sqrtf((float)head_dim);
+
+    float * v_new_buf = (float *)malloc(head_dim * sizeof(float));
+    if (!v_new_buf) {
+        return;
+    }
+
+    for (int64_t h_idx = h_start; h_idx < h_end; h_idx++) {
+        const int64_t batch_idx = h_idx / n_heads;
+        const int64_t head_idx  = h_idx % n_heads;
+
+        const int64_t qkv_head_offset = batch_idx * (head_dim * n_tokens * n_heads) + head_idx * (head_dim * n_tokens);
+        const int64_t qkv_token_stride = head_dim;
+        const int64_t g_head_offset = batch_idx * (n_tokens * n_heads) + head_idx * n_tokens;
+        const int64_t state_head_offset = batch_idx * (head_dim * head_dim * n_heads) + head_idx * (head_dim * head_dim);
+        const int64_t out_head_offset = batch_idx * (head_dim * n_heads * n_tokens) + head_idx * head_dim;
+        const int64_t out_token_stride = head_dim * n_heads;
+
+        for (int64_t i = 0; i < head_dim * head_dim; i++) {
+            state_out[state_head_offset + i] = state_in[state_head_offset + i];
+        }
+
+        float * state = state_out + state_head_offset;
+
+        for (int64_t t = 0; t < n_tokens; t++) {
+            const float * q_t = q_data + qkv_head_offset + t * qkv_token_stride;
+            const float * k_t = k_data + qkv_head_offset + t * qkv_token_stride;
+            const float * v_t = v_data + qkv_head_offset + t * qkv_token_stride;
+
+            float g_val    = g_data[g_head_offset + t];
+            float beta_raw = beta_data[g_head_offset + t];
+
+            float q_norm_sq = 0.0f, k_norm_sq = 0.0f;
+            for (int64_t i = 0; i < head_dim; i++) {
+                q_norm_sq += q_t[i] * q_t[i];
+                k_norm_sq += k_t[i] * k_t[i];
+            }
+            float q_norm_inv = 1.0f / sqrtf(q_norm_sq + eps);
+            float k_norm_inv = 1.0f / sqrtf(k_norm_sq + eps);
+
+            float beta_val = 1.0f / (1.0f + expf(-beta_raw));
+            float decay = expf(fminf(g_val, 50.0f));
+
+            float attn_score = 0.0f;
+            for (int64_t i = 0; i < head_dim; i++) {
+                attn_score += (k_t[i] * k_norm_inv) * (q_t[i] * q_norm_inv * scale);
+            }
+
+            float * out_t = out_data + out_head_offset + t * out_token_stride;
+
+            for (int64_t row = 0; row < head_dim; row++) {
+                float v_prime = 0.0f;
+                float out_val = 0.0f;
+
+                for (int64_t col = 0; col < head_dim; col++) {
+                    float k_col = k_t[col] * k_norm_inv;
+                    float q_col = q_t[col] * q_norm_inv * scale;
+                    float s = state[row + col * head_dim];
+
+                    v_prime += s * k_col * beta_val * decay;
+                    out_val += s * q_col * decay;
+                }
+
+                float v_new = v_t[row] * beta_val - v_prime;
+                v_new_buf[row] = v_new;
+                out_t[row] = out_val + v_new * attn_score;
+            }
+
+            for (int64_t col = 0; col < head_dim; col++) {
+                float k_col = k_t[col] * k_norm_inv;
+                for (int64_t row = 0; row < head_dim; row++) {
+                    float s = state[row + col * head_dim];
+                    s = decay * s + v_new_buf[row] * k_col;
+                    state[row + col * head_dim] = fminf(fmaxf(s, -1e6f), 1e6f);
+                }
+            }
+        }
+    }
+
+    free(v_new_buf);
+}
+
+void ggml_compute_forward_delta_net(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            ggml_compute_forward_delta_net_f32(params, dst);
+            break;
+        default:
+            GGML_ABORT("fatal error");
+    }
+}
+
 // ggml_compute_forward_map_custom1
 
 void ggml_compute_forward_map_custom1(
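Editor's note (not part of the patch): per head and per token t, the scalar loop above implements a gated delta-rule update that can be summarized, with d = head_dim, as

$$\hat q_t = \frac{q_t}{\lVert q_t\rVert\,\sqrt{d}},\qquad \hat k_t = \frac{k_t}{\lVert k_t\rVert},\qquad \beta_t = \sigma(\beta^{\text{raw}}_t),\qquad \gamma_t = e^{\min(g_t,\,50)},$$

$$S_t = \gamma_t S_{t-1} + \beta_t\left(v_t - \gamma_t S_{t-1}\hat k_t\right)\hat k_t^{\top},\qquad o_t = S_t\,\hat q_t,$$

where the state entries are additionally clamped to ±1e6 after each update, matching the fminf/fmaxf calls in the inner loop.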
@@ -102,6 +102,7 @@ void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, s
 void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_delta_net(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom3(const struct ggml_compute_params * params, struct ggml_tensor * dst);
File diff suppressed because it is too large
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -55,6 +55,7 @@
 #include "ggml-cuda/set-rows.cuh"
 #include "ggml-cuda/pad_reflect_1d.cuh"
 #include "ggml-cuda/solve_tri.cuh"
+#include "ggml-cuda/delta-net.cuh"
 #include "ggml-cuda/tri.cuh"
 #include "ggml-cuda/cumsum.cuh"
 #include "ggml-cuda/fill.cuh"
@@ -2735,6 +2736,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SOLVE_TRI:
             ggml_cuda_op_solve_tri(ctx, dst);
             break;
+        case GGML_OP_DELTA_NET:
+            ggml_cuda_op_delta_net(ctx, dst);
+            break;
         case GGML_OP_FILL:
             ggml_cuda_op_fill(ctx, dst);
             break;
@@ -2904,6 +2908,13 @@ static bool check_node_graph_compatibility(ggml_cgraph * cgraph,
 #endif
         }
 
+        if (node->op == GGML_OP_DELTA_NET) {
+            use_cuda_graph = false;
+#ifndef NDEBUG
+            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to DELTA_NET recurrent state\n", __func__);
+#endif
+        }
+
         if (!use_cuda_graph) {
             break;
         }
@@ -4632,6 +4643,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_DIAG:
         case GGML_OP_SOLVE_TRI:
             return true;
+        case GGML_OP_DELTA_NET:
+            return op->src[0]->ne[0] <= 256 && op->src[2]->ne[0] <= 256;
 
         default:
             return false;
@@ -1,86 +1,533 @@
 #include "common.cuh"
 #include "ggml.h"
 #include "solve_tri.cuh"
+#include "ggml-cuda.h"
+#include <cublas_v2.h>
 
 #define MAX_N_FAST 64
-#define MAX_K_FAST 32
+#define MAX_K_FAST 64
 
-static __global__ void get_batch_pointers(const float * A,
-                                          float * X,
-                                          const float ** A_ptrs,
-                                          float ** X_ptrs,
-                                          int64_t ne02,
-                                          int64_t total_batches,
-                                          size_t s02,
-                                          size_t s03,
-                                          size_t s2,
-                                          size_t s3) {
-    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx >= total_batches) {
-        return;
-    }
-
-    const int64_t i3 = idx / ne02;
-    const int64_t i2 = idx % ne02;
-
-    A_ptrs[idx] = A + i3 * s03 + i2 * s02;
-    X_ptrs[idx] = X + i3 * s3 + i2 * s2;
-}
-
-static void solve_tri_f32_cublas(ggml_backend_cuda_context & ctx,
-                                 const float * A,
-                                 const float * B,
-                                 float * X,
-                                 int n,
-                                 int k,
-                                 int64_t ne02,
-                                 int64_t ne03,
-                                 size_t s02,
-                                 size_t s03,
-                                 size_t s12,
-                                 size_t s13,
-                                 size_t s2,
-                                 size_t s3,
-                                 cudaStream_t stream) {
-    const float alpha = 1.0f;
-    const int64_t total_batches = ne02 * ne03;
-    if (total_batches == 0) {
-        return;
-    }
-
-    // Bulk copy B -> X (contiguous tensors)
-    if (X != B) {
-        const int64_t total_elements_BX = n * k * total_batches;
-        CUDA_CHECK(cudaMemcpyAsync(X, B, total_elements_BX * sizeof(float), cudaMemcpyDeviceToDevice, stream));
-    }
-
-    const int id = ggml_cuda_get_device();
-
-    ggml_cuda_pool_alloc<const float *> A_ptrs_alloc(ctx.pool(id), total_batches);
-    ggml_cuda_pool_alloc<float *> X_ptrs_alloc(ctx.pool(id), total_batches);
-
-    const float ** A_ptrs_dev = A_ptrs_alloc.get();
-    float ** X_ptrs_dev = X_ptrs_alloc.get();
-
-    get_batch_pointers<<<(total_batches + 255) / 256, 256, 0, stream>>>(A, X, A_ptrs_dev, X_ptrs_dev, ne02,
-                                                                        total_batches, s02, s03, s2, s3);
-
-    CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
-
-    // Yes, this is necessary, without this we get RMSE errors
-    CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_DEFAULT_MATH));
-    CUBLAS_CHECK(cublasStrsmBatched(ctx.cublas_handle(id), CUBLAS_SIDE_RIGHT, CUBLAS_FILL_MODE_UPPER, CUBLAS_OP_N,
-                                    CUBLAS_DIAG_NON_UNIT, k, n, &alpha, A_ptrs_dev, n, X_ptrs_dev, k, total_batches));
-
-    // revert to standard mode from common.cuh
-    CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_TF32_TENSOR_OP_MATH));
-
-    GGML_UNUSED_VARS(s12, s13);
-}
-
-// ======================
-// Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction
-// ======================
+// Kernel to set up pointer arrays for batched cuBLAS TRSM
+// This avoids host-device copy during CUDA graph capture
+static __global__ void setup_trsm_batch_pointers(
+        const float * A,
+        float * X,
+        const float ** A_ptrs,
+        float ** X_ptrs,
+        const int64_t ne02,
+        const int64_t total_batches,
+        const size_t nb02, // stride for A dim 2 (in floats)
+        const size_t nb03, // stride for A dim 3 (in floats)
+        const size_t nb2,  // stride for X dim 2 (in floats)
+        const size_t nb3   // stride for X dim 3 (in floats)
+) {
+    const int64_t batch_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (batch_idx >= total_batches) return;
+
+    // Decompose batch_idx into i02, i03
+    const int64_t i02 = batch_idx % ne02;
+    const int64_t i03 = batch_idx / ne02;
+
+    A_ptrs[batch_idx] = A + i02 * nb02 + i03 * nb03;
+    X_ptrs[batch_idx] = X + i02 * nb2 + i03 * nb3;
+}
+
+// Latency-optimized kernel for n=64, k=64 (single-token generation)
+static __global__ void solve_tri_f32_64x64_latency(
+        const float * __restrict__ A,
+        const float * __restrict__ B,
+        float * __restrict__ X,
+        const uint3 ne02,
+        const size_t nb02,
+        const size_t nb03,
+        const size_t nb12,
+        const size_t nb13,
+        const size_t nb2,
+        const size_t nb3)
+{
+    const int batch_idx = blockIdx.x;
+    const int lane = threadIdx.x;
+    const int warp_id = threadIdx.y;
+
+    const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02);
+    const int64_t i02 = i02_i03.y;
+    const int64_t i03 = i02_i03.x;
+
+    const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);
+    const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);
+    float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
+
+    // Shared memory: A is 64x64, X is 64x65 (padded for bank conflicts)
+    __shared__ float sA[64 * 64];
+    __shared__ float sX[64 * 65];
+    __shared__ float sDiagInv[64]; // Precomputed 1/diagonal
+
+    const int tid = lane + warp_id * WARP_SIZE;
+
+    // Cooperative load of A matrix (4096 elements / 512 threads = 8 per thread)
+#pragma unroll 8
+    for (int i = tid; i < 64 * 64; i += 512) {
+        sA[i] = A_batch[i];
+    }
+
+    // Cooperative load of B matrix into sX with padding
+#pragma unroll 8
+    for (int i = tid; i < 64 * 64; i += 512) {
+        const int row = i / 64;
+        const int col = i % 64;
+        sX[row * 65 + col] = B_batch[i];
+    }
+
+    __syncthreads();
+
+    // Precompute diagonal inverses (first 2 warps handle this)
+    if (warp_id == 0) {
+        if (lane < 32) {
+            sDiagInv[lane] = 1.0f / sA[lane * 64 + lane];
+        }
+    }
+    if (warp_id == 1) {
+        if (lane < 32) {
+            sDiagInv[32 + lane] = 1.0f / sA[(32 + lane) * 64 + (32 + lane)];
+        }
+    }
+
+    __syncthreads();
+
+    // Each warp handles 4 columns: cols = warp_id*4 to warp_id*4+3
+    const int col_base = warp_id * 4;
+
+#pragma unroll 1
+    for (int row = 0; row < 64; ++row) {
+        float sum0 = 0.0f, sum1 = 0.0f, sum2 = 0.0f, sum3 = 0.0f;
+
+        if (row > 0) {
+            for (int j = lane; j < row; j += WARP_SIZE) {
+                const float a_val = sA[row * 64 + j];
+                sum0 += a_val * sX[j * 65 + col_base + 0];
+                sum1 += a_val * sX[j * 65 + col_base + 1];
+                sum2 += a_val * sX[j * 65 + col_base + 2];
+                sum3 += a_val * sX[j * 65 + col_base + 3];
+            }
+        }
+
+        sum0 = warp_reduce_sum(sum0);
+        sum1 = warp_reduce_sum(sum1);
+        sum2 = warp_reduce_sum(sum2);
+        sum3 = warp_reduce_sum(sum3);
+
+        if (lane == 0) {
+            const float inv_diag = sDiagInv[row];
+            sX[row * 65 + col_base + 0] = (sX[row * 65 + col_base + 0] - sum0) * inv_diag;
+            sX[row * 65 + col_base + 1] = (sX[row * 65 + col_base + 1] - sum1) * inv_diag;
+            sX[row * 65 + col_base + 2] = (sX[row * 65 + col_base + 2] - sum2) * inv_diag;
+            sX[row * 65 + col_base + 3] = (sX[row * 65 + col_base + 3] - sum3) * inv_diag;
+        }
+
+        __syncthreads();
+    }
+
+    // Cooperative write results back
+#pragma unroll 8
+    for (int i = tid; i < 64 * 64; i += 512) {
+        const int row = i / 64;
+        const int col = i % 64;
+        X_batch[i] = sX[row * 65 + col];
+    }
+}
+
+static __global__ void solve_tri_f32_64x64_opt(const float * __restrict__ A,
+                                               const float * __restrict__ B,
+                                               float * __restrict__ X,
+                                               const uint3 ne02,
+                                               const size_t nb02,
+                                               const size_t nb03,
+                                               const size_t nb12,
+                                               const size_t nb13,
+                                               const size_t nb2,
+                                               const size_t nb3) {
+    const int batch_idx = blockIdx.x;
+    const int lane = threadIdx.x;
+    const int warp_id = threadIdx.y;
+
+    const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02);
+    const int64_t i02 = i02_i03.y;
+    const int64_t i03 = i02_i03.x;
+
+    const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);
+    const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);
+    float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
+
+    // Shared memory: A is 64x64, sXt is 64x65 (padded)
+    __shared__ float sA[64 * 64];
+    __shared__ float sXt[64 * 65];
+
+    const int tid = lane + warp_id * WARP_SIZE;
+
+    // Cooperative load of A matrix (4096 elements / 1024 threads = 4 per thread)
+#pragma unroll 4
+    for (int i = tid; i < 64 * 64; i += 1024) {
+        sA[i] = A_batch[i];
+    }
+
+    // Cooperative load of B matrix transposed into sXt
+    // sXt[col * 65 + row] = B[row * 64 + col]
+#pragma unroll 4
+    for (int i = tid; i < 64 * 64; i += 1024) {
+        const int row = i / 64;
+        const int col = i % 64;
+        sXt[col * 65 + row] = B_batch[row * 64 + col];
+    }
+
+    __syncthreads();
+
+    // Each warp handles 2 columns: col0 = warp_id*2, col1 = warp_id*2 + 1
+    const int col0 = warp_id * 2;
+    const int col1 = warp_id * 2 + 1;
+
+    // Forward substitution with all columns processed in parallel
+    // Each row depends on previous rows, but different columns are independent
+#pragma unroll 1
+    for (int row = 0; row < 64; ++row) {
+        // Each lane computes partial sum for indices it handles
+        float sum0 = 0.0f;
+        float sum1 = 0.0f;
+
+        // Sum over j < row
+        // For row <= 32: each lane handles at most 1 element
+        // For row > 32: each lane handles at most 2 elements
+        if (lane < row) {
+            const float a_val = sA[row * 64 + lane];
+            sum0 = a_val * sXt[col0 * 65 + lane];
+            sum1 = a_val * sXt[col1 * 65 + lane];
+        }
+        if (row > WARP_SIZE) {
+            const int j2 = lane + WARP_SIZE;
+            if (j2 < row) {
+                const float a_val2 = sA[row * 64 + j2];
+                sum0 += a_val2 * sXt[col0 * 65 + j2];
+                sum1 += a_val2 * sXt[col1 * 65 + j2];
+            }
+        }
+
+        // Warp-level reduction
+        sum0 = warp_reduce_sum(sum0);
+        sum1 = warp_reduce_sum(sum1);
+
+        // Lane 0 computes and stores the result
+        if (lane == 0) {
+            const float a_diag = sA[row * 64 + row];
+            const float inv_diag = 1.0f / a_diag;
+            sXt[col0 * 65 + row] = (sXt[col0 * 65 + row] - sum0) * inv_diag;
+            sXt[col1 * 65 + row] = (sXt[col1 * 65 + row] - sum1) * inv_diag;
+        }
+
+        // Sync within warp to ensure writes are visible before next row reads
+        __syncwarp();
+    }
+
+    __syncthreads();
+
+    // Cooperative write of results back (transpose sXt to X)
+#pragma unroll 4
+    for (int i = tid; i < 64 * 64; i += 1024) {
+        const int row = i / 64;
+        const int col = i % 64;
+        X_batch[row * 64 + col] = sXt[col * 65 + row];
+    }
+}
+
+static __global__ void solve_tri_f32_128x128_opt(const float * __restrict__ A,
+                                                 const float * __restrict__ B,
+                                                 float * __restrict__ X,
+                                                 const uint3 ne02,
+                                                 const size_t nb02,
+                                                 const size_t nb03,
+                                                 const size_t nb12,
+                                                 const size_t nb13,
+                                                 const size_t nb2,
+                                                 const size_t nb3,
+                                                 const int n,
+                                                 const int k) {
+    const int batch_idx = blockIdx.x;
+    const int lane = threadIdx.x;
+    const int warp_id = threadIdx.y;
+
+    const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02);
+    const int64_t i02 = i02_i03.y;
+    const int64_t i03 = i02_i03.x;
+
+    const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);
+    const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);
+    float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
+
+    // Shared memory with padding to avoid bank conflicts
+    // Layout: sA[128][128] + sXt[128][129]
+    extern __shared__ char smem_raw[];
+    float * sA  = (float *)smem_raw;  // 128×128 (zero-initialized for unused parts)
+    float * sXt = sA + 128 * 128;     // 128×129 (padded)
+
+    const int tid = lane + warp_id * WARP_SIZE;
+
+    // Zero-initialize shared memory first (important for variable n, k)
+#pragma unroll 16
+    for (int i = tid; i < 128 * 128; i += 1024) {
+        sA[i] = 0.0f;
+    }
+#pragma unroll 16
+    for (int i = tid; i < 128 * 129; i += 1024) {
+        sXt[i] = 0.0f;
+    }
+    __syncthreads();
+
+    // Cooperative load of A matrix (n×n elements)
+    for (int i = tid; i < n * n; i += 1024) {
+        const int row = i / n;
+        const int col = i % n;
+        sA[row * 128 + col] = A_batch[row * n + col];
+    }
+
+    // Cooperative load of B matrix transposed into sXt
+    // sXt[col * 129 + row] = B[row * k + col]
+    for (int i = tid; i < n * k; i += 1024) {
+        const int row = i / k;
+        const int col = i % k;
+        sXt[col * 129 + row] = B_batch[row * k + col];
+    }
+
+    __syncthreads();
+
+    // Each warp handles columns: col_base to col_base+3
+    // But only process if col < k
+    const int col_base = warp_id * 4;
+
+    // Forward substitution with all columns processed in parallel
+    for (int row = 0; row < n; ++row) {
+        float sum0 = 0.0f, sum1 = 0.0f, sum2 = 0.0f, sum3 = 0.0f;
+
+        // Sum over j < row - each lane handles multiple elements
+        for (int j = lane; j < row; j += WARP_SIZE) {
+            const float a_val = sA[row * 128 + j];
+            if (col_base + 0 < k) sum0 += a_val * sXt[(col_base + 0) * 129 + j];
+            if (col_base + 1 < k) sum1 += a_val * sXt[(col_base + 1) * 129 + j];
+            if (col_base + 2 < k) sum2 += a_val * sXt[(col_base + 2) * 129 + j];
+            if (col_base + 3 < k) sum3 += a_val * sXt[(col_base + 3) * 129 + j];
+        }
+
+        // Warp-level reduction
+        sum0 = warp_reduce_sum(sum0);
+        sum1 = warp_reduce_sum(sum1);
+        sum2 = warp_reduce_sum(sum2);
+        sum3 = warp_reduce_sum(sum3);
+
+        // Lane 0 computes and stores the result
+        if (lane == 0) {
+            const float inv_diag = 1.0f / sA[row * 128 + row];
+            if (col_base + 0 < k) {
+                sXt[(col_base + 0) * 129 + row] = (sXt[(col_base + 0) * 129 + row] - sum0) * inv_diag;
+            }
+            if (col_base + 1 < k) {
+                sXt[(col_base + 1) * 129 + row] = (sXt[(col_base + 1) * 129 + row] - sum1) * inv_diag;
+            }
+            if (col_base + 2 < k) {
+                sXt[(col_base + 2) * 129 + row] = (sXt[(col_base + 2) * 129 + row] - sum2) * inv_diag;
+            }
+            if (col_base + 3 < k) {
+                sXt[(col_base + 3) * 129 + row] = (sXt[(col_base + 3) * 129 + row] - sum3) * inv_diag;
+            }
+        }
+
+        __syncwarp();
+    }
+
+    __syncthreads();
+
+    // Cooperative write of results back (transpose sXt to X)
+    for (int i = tid; i < n * k; i += 1024) {
+        const int row = i / k;
+        const int col = i % k;
+        X_batch[row * k + col] = sXt[col * 129 + row];
+    }
+}
+
+static __global__ void solve_tri_f32_256x256_tiled(const float * __restrict__ A,
+                                                   const float * __restrict__ B,
+                                                   float * __restrict__ X,
+                                                   const uint3 ne02,
+                                                   const size_t nb02,
+                                                   const size_t nb03,
+                                                   const size_t nb12,
+                                                   const size_t nb13,
+                                                   const size_t nb2,
+                                                   const size_t nb3,
+                                                   const int n,
+                                                   const int k) {
+    const int batch_idx = blockIdx.x;
+    const int lane = threadIdx.x;
+    const int warp_id = threadIdx.y;
+
+    const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02);
+    const int64_t i02 = i02_i03.y;
+    const int64_t i03 = i02_i03.x;
+
+    const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);
+    const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);
+    float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
+
+    // Tiled approach using 64×64 tiles to fit in shared memory
+    constexpr int TILE_SIZE = 64;
+
+    extern __shared__ char smem_raw[];
+    float * sA_tile  = (float *)smem_raw;                    // 64×64 = 16KB
+    float * sXt_tile = sA_tile + TILE_SIZE * TILE_SIZE;      // 64×65 = 16.25KB (padded)
+    float * sA_off   = sXt_tile + TILE_SIZE * (TILE_SIZE+1); // 64×64 = 16KB (for off-diagonal blocks)
+
+    const int tid = lane + warp_id * WARP_SIZE;
+
+    // Initialize X = B (we'll solve in-place conceptually, using global memory)
+    for (int i = tid; i < n * k; i += 1024) {
+        X_batch[i] = B_batch[i];
+    }
+    __syncthreads();
+
+    // Process tile-by-tile along the diagonal
+    for (int tile_row = 0; tile_row < n; tile_row += TILE_SIZE) {
+        const int tile_n = min(TILE_SIZE, n - tile_row); // Actual rows in this tile
+
+        // Zero-init and load diagonal tile of A
+        for (int i = tid; i < TILE_SIZE * TILE_SIZE; i += 1024) {
+            sA_tile[i] = 0.0f;
+        }
+        __syncthreads();
+
+        for (int i = tid; i < tile_n * tile_n; i += 1024) {
+            int local_row = i / tile_n;
+            int local_col = i % tile_n;
+            sA_tile[local_row * TILE_SIZE + local_col] = A_batch[(tile_row + local_row) * n + tile_row + local_col];
+        }
+        __syncthreads();
+
+        // For each column tile of X
+        for (int tile_col = 0; tile_col < k; tile_col += TILE_SIZE) {
+            const int tile_k = min(TILE_SIZE, k - tile_col); // Actual columns in this tile
+
+            // Zero-init and load X tile transposed
+            for (int i = tid; i < TILE_SIZE * (TILE_SIZE+1); i += 1024) {
+                sXt_tile[i] = 0.0f;
+            }
+            __syncthreads();
+
+            for (int i = tid; i < tile_n * tile_k; i += 1024) {
+                int local_row = i / tile_k;
+                int local_col = i % tile_k;
+                sXt_tile[local_col * (TILE_SIZE+1) + local_row] =
+                    X_batch[(tile_row + local_row) * k + tile_col + local_col];
+            }
+            __syncthreads();
+
+            // Apply updates from previous tile rows
+            for (int prev_tile = 0; prev_tile < tile_row; prev_tile += TILE_SIZE) {
+                const int prev_n = min(TILE_SIZE, n - prev_tile);
+
+                // Zero-init and load off-diagonal block
+                for (int i = tid; i < TILE_SIZE * TILE_SIZE; i += 1024) {
+                    sA_off[i] = 0.0f;
+                }
+                __syncthreads();
+
+                for (int i = tid; i < tile_n * prev_n; i += 1024) {
+                    int local_row = i / prev_n;
+                    int local_col = i % prev_n;
+                    sA_off[local_row * TILE_SIZE + local_col] = A_batch[(tile_row + local_row) * n + prev_tile + local_col];
+                }
+                __syncthreads();
+
+                // Update: X_tile -= A_off @ X_prev
+                int col0 = warp_id * 2;
+                int col1 = warp_id * 2 + 1;
+
+                for (int row = 0; row < tile_n; row++) {
+                    float sum0 = 0.0f, sum1 = 0.0f;
+
+                    for (int j = lane; j < prev_n; j += WARP_SIZE) {
+                        float a_val = sA_off[row * TILE_SIZE + j];
+                        if (col0 < tile_k) {
+                            float x_prev0 = X_batch[(prev_tile + j) * k + tile_col + col0];
+                            sum0 += a_val * x_prev0;
+                        }
+                        if (col1 < tile_k) {
+                            float x_prev1 = X_batch[(prev_tile + j) * k + tile_col + col1];
+                            sum1 += a_val * x_prev1;
+                        }
+                    }
+
+                    sum0 = warp_reduce_sum(sum0);
+                    sum1 = warp_reduce_sum(sum1);
+
+                    if (lane == 0) {
+                        if (col0 < tile_k) {
+                            sXt_tile[col0 * (TILE_SIZE+1) + row] -= sum0;
+                        }
+                        if (col1 < tile_k) {
+                            sXt_tile[col1 * (TILE_SIZE+1) + row] -= sum1;
+                        }
+                    }
+                    __syncwarp();
+                }
+                __syncthreads();
+            }
+
+            // Solve the diagonal tile
+            int col0 = warp_id * 2;
+            int col1 = warp_id * 2 + 1;
+
+            for (int row = 0; row < tile_n; ++row) {
+                float sum0 = 0.0f, sum1 = 0.0f;
+
+                if (lane < row) {
+                    float a_val = sA_tile[row * TILE_SIZE + lane];
+                    if (col0 < tile_k) sum0 = a_val * sXt_tile[col0 * (TILE_SIZE+1) + lane];
+                    if (col1 < tile_k) sum1 = a_val * sXt_tile[col1 * (TILE_SIZE+1) + lane];
+                }
+                if (row > WARP_SIZE) {
+                    int j2 = lane + WARP_SIZE;
+                    if (j2 < row) {
+                        float a_val2 = sA_tile[row * TILE_SIZE + j2];
+                        if (col0 < tile_k) sum0 += a_val2 * sXt_tile[col0 * (TILE_SIZE+1) + j2];
+                        if (col1 < tile_k) sum1 += a_val2 * sXt_tile[col1 * (TILE_SIZE+1) + j2];
+                    }
+                }
+
+                sum0 = warp_reduce_sum(sum0);
+                sum1 = warp_reduce_sum(sum1);
+
+                if (lane == 0) {
+                    float inv_diag = 1.0f / sA_tile[row * TILE_SIZE + row];
+                    if (col0 < tile_k) {
+                        sXt_tile[col0 * (TILE_SIZE+1) + row] =
+                            (sXt_tile[col0 * (TILE_SIZE+1) + row] - sum0) * inv_diag;
+                    }
+                    if (col1 < tile_k) {
+                        sXt_tile[col1 * (TILE_SIZE+1) + row] =
+                            (sXt_tile[col1 * (TILE_SIZE+1) + row] - sum1) * inv_diag;
+                    }
+                }
+                __syncwarp();
+            }
+            __syncthreads();
+
+            // Write solved tile back to global memory
+            for (int i = tid; i < tile_n * tile_k; i += 1024) {
+                int local_row = i / tile_k;
+                int local_col = i % tile_k;
+                X_batch[(tile_row + local_row) * k + tile_col + local_col] =
+                    sXt_tile[local_col * (TILE_SIZE+1) + local_row];
+            }
+            __syncthreads();
+        }
+    }
+}
+
 // When ncols_template == 0 the bounds for the loops in this function are not
 // known and can't be unrolled. As we want to keep pragma unroll for all other
 // cases we supress the clang transformation warning here.
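Editor's note (not part of the patch): each kernel above is a variant of the same row-by-row forward substitution for the lower-triangular system $A X = B$,

$$X_{r,c} \;=\; \frac{1}{A_{r,r}}\Bigl(B_{r,c} \;-\; \sum_{j<r} A_{r,j}\,X_{j,c}\Bigr),\qquad r = 0,\dots,n-1,$$

where the right-hand-side columns $c$ are distributed across warps, the partial sums over $j$ are split across the 32 lanes of a warp and combined with warp_reduce_sum, and X is kept transposed in shared memory with a padded leading dimension (65 or 129) to avoid bank conflicts.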
@@ -88,7 +535,9 @@ static void solve_tri_f32_cublas(ggml_backend_cuda_context & ctx,
 # pragma clang diagnostic push
 # pragma clang diagnostic ignored "-Wpass-failed"
 #endif // __clang__
-template <int n_template, int k_template>
+// Template parameters: n_template/k_template are the matrix dimensions when known at compile time (0 = runtime)
+// threads_y_template is the number of threads in y dimension (max 32 to stay within 1024 thread limit)
+template <int n_template, int k_template, int threads_y_template>
 static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
                                           const float * __restrict__ B,
                                           float * __restrict__ X,
@@ -103,14 +552,10 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
                                           const int k_arg) {
     const int n = n_template == 0 ? n_arg : n_template;
     const int k = k_template == 0 ? k_arg : k_template;
+    const int threads_y = threads_y_template == 0 ? blockDim.y : threads_y_template;
 
     const int batch_idx = blockIdx.x;
     const int lane = threadIdx.x;
-    const int col_idx = threadIdx.y;
-
-    if (col_idx >= k) {
-        return;
-    }
 
     const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02);
     const int64_t i02 = i02_i03.y;
@@ -121,58 +566,94 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
     float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
 
     __shared__ float sA[MAX_N_FAST * MAX_N_FAST];
+    __shared__ float sXt[MAX_N_FAST * (MAX_K_FAST + 1)];
+
     const int offset = threadIdx.x + threadIdx.y * blockDim.x;
+    const int block_threads = blockDim.x * blockDim.y;
+
+    // Load A matrix into shared memory
 #pragma unroll
-    for (int i = 0; i < n * n; i += k * WARP_SIZE) {
-        const int i0 = i + offset;
+    for (int i = 0; i < n * n; i += block_threads) {
+        int i0 = i + offset;
         if (i0 < n * n) {
             sA[i0] = A_batch[i0];
         }
     }
 
+    const int rows_per_warp = (n + WARP_SIZE - 1) / WARP_SIZE;
+    const int cols_per_thread = (k + threads_y - 1) / threads_y;
+
+    // Load B matrix into shared memory (transposed as sXt)
+    // Each thread handles multiple columns when k > threads_y
+    for (int c = 0; c < cols_per_thread; c++) {
+        const int col_idx = threadIdx.y + c * threads_y;
+        if (col_idx < k) {
+#pragma unroll
+            for (int i = 0; i < rows_per_warp; i++) {
+                const int i0 = lane + i * WARP_SIZE;
+                if (i0 < n) {
+                    sXt[col_idx * n + i0] = B_batch[i0 * k + col_idx];
+                }
+            }
+        }
+    }
+
     __syncthreads();
 
-    float x_low = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f;
-    float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f;
-
-    const int half = WARP_SIZE;
-    const int nrows_low = (n < half) ? n : half;
-
+    // Solve for each column this thread handles
+    for (int c = 0; c < cols_per_thread; c++) {
+        const int col_idx = threadIdx.y + c * threads_y;
+        if (col_idx >= k) {
+            continue;
+        }
+
 #pragma unroll
-    for (int row = 0; row < nrows_low; ++row) {
-        float sum = 0.0f;
-        if (lane < row) {
-            sum += sA[row * n + lane] * x_low;
-        }
-        sum = warp_reduce_sum(sum);
-
-        if (lane == row) {
-            x_low = (x_low - sum) / sA[row * n + row];
+        for (int row = 0; row < n; ++row) {
+            float sum = 0.0f;
+            {
+                int j = lane;
+                if (j < row) {
+                    sum += sA[row * n + j] * sXt[col_idx * n + j];
+                }
+            }
+            if (row >= WARP_SIZE) {
+                int j = WARP_SIZE + lane;
+                if (j < row) {
+                    sum += sA[row * n + j] * sXt[col_idx * n + j];
+                }
+            }
+
+            sum = warp_reduce_sum(sum);
+
+            if (lane == 0) {
+                const float b_val = sXt[col_idx * n + row];
+                const float a_diag = sA[row * n + row];
+                // no safeguards for division by zero because that indicates corrupt
+                // data anyway
+                sXt[col_idx * n + row] = (b_val - sum) / a_diag;
+            }
+        }
+
+        // Sync between columns to ensure writes are visible
+        if (c + 1 < cols_per_thread) {
+            __syncwarp();
         }
     }
 
-#pragma unroll
-    for (int row = half; row < n; ++row) {
-        float sum = sA[row * n + lane] * x_low;
-        const int j = half + lane;
-        if (j < row) {
-            sum += sA[row * n + j] * x_high;
-        }
-        sum = warp_reduce_sum(sum);
-
-        if (lane == row - half) {
-            x_high = (x_high - sum) / sA[row * n + row];
-        }
-    }
-
+    __syncthreads();
+
+    // Write results back
+    for (int c = 0; c < cols_per_thread; c++) {
+        const int col_idx = threadIdx.y + c * threads_y;
+        if (col_idx < k) {
 #pragma unroll
-        for (int rr = 0; rr < 2; ++rr) {
-            const int row = rr * WARP_SIZE + lane;
-            if (row < n) {
-                const float val = (row < half) ? x_low : x_high;
-                X_batch[row * k + col_idx] = val;
+            for (int i = 0; i < rows_per_warp; i++) {
+                const int i0 = lane + i * WARP_SIZE;
+                if (i0 < n) {
+                    X_batch[i0 * k + col_idx] = sXt[col_idx * n + i0];
+                }
             }
         }
     }
 }
@@ -180,6 +661,76 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
 # pragma clang diagnostic pop
 #endif // __clang__
 
+// cuBLAS batched TRSM fallback for larger matrices or as robust path
+// Solves A * X = B where A is lower triangular
+// This function modifies X in-place (X should be initialized with B)
+static void solve_tri_f32_cublas(
+        ggml_backend_cuda_context & ctx,
+        const float * A,
+        float * X, // Input: B, Output: solution X (in-place)
+        int n,
+        int k,
+        int64_t ne02,
+        int64_t ne03,
+        size_t nb02,
+        size_t nb03,
+        size_t nb2,
+        size_t nb3,
+        cudaStream_t stream
+) {
+    const int64_t total_batches = ne02 * ne03;
+
+    // Allocate pointer arrays on device
+    ggml_cuda_pool_alloc<const float *> A_ptrs(ctx.pool(), total_batches);
+    ggml_cuda_pool_alloc<float *> X_ptrs(ctx.pool(), total_batches);
+
+    // Set up pointer arrays on device (CUDA graph compatible)
+    {
+        const int block_size = 256;
+        const int grid_size = (total_batches + block_size - 1) / block_size;
+        setup_trsm_batch_pointers<<<grid_size, block_size, 0, stream>>>(
+            A, X,
+            A_ptrs.get(), X_ptrs.get(),
+            ne02, total_batches,
+            nb02, nb03, nb2, nb3
+        );
+        CUDA_CHECK(cudaGetLastError());
+    }
+
+    // Get cuBLAS handle and set stream
+    cublasHandle_t handle = ctx.cublas_handle();
+    cublasSetStream(handle, stream);
+
+    // Save current math mode and set to default for accuracy
+    // (TF32 can cause numerical issues with triangular solves)
+    cublasMath_t prev_math_mode;
+    cublasGetMathMode(handle, &prev_math_mode);
+    cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH);
+
+    const float alpha = 1.0f;
+
+    cublasStatus_t status = cublasStrsmBatched(
+        handle,
+        CUBLAS_SIDE_RIGHT,      // A is on the right: X * A = B
+        CUBLAS_FILL_MODE_UPPER, // A^T is upper (since A is lower in row-major)
+        CUBLAS_OP_N,            // No additional transpose
+        CUBLAS_DIAG_NON_UNIT,   // Diagonal is not assumed to be 1
+        k,                      // m: rows of X^T (columns of X)
+        n,                      // n: columns of X^T (rows of X) = size of A
+        &alpha,
+        (const float **)A_ptrs.get(), n, // lda = n (leading dimension)
+        (float **)X_ptrs.get(), k,       // ldb = k (leading dimension of X^T)
+        total_batches
+    );
+
+    // Restore previous math mode
+    cublasSetMathMode(handle, prev_math_mode);
+
+    if (status != CUBLAS_STATUS_SUCCESS) {
+        GGML_LOG_ERROR("cuBLAS batched TRSM failed: %d\n", (int)status);
+    }
+}
+
 static void solve_tri_f32_cuda(const float * A,
                                const float * B,
                                float * X,
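Editor's note on the cuBLAS call above (not part of the patch): ggml stores A and X row-major while cuBLAS is column-major, so cuBLAS reads the row-major $n\times n$ matrix $A$ as $A^\top$ and the row-major $n\times k$ matrix $X$ as the $k\times n$ matrix $X^\top$. The row-major system $A X = B$ therefore becomes

$$X^\top A^\top = B^\top,$$

a right-sided solve against the upper-triangular factor $A^\top$, which is exactly the CUBLAS_SIDE_RIGHT / CUBLAS_FILL_MODE_UPPER call with m = k, n = n, lda = n and ldb = k used here.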
@@ -195,81 +746,133 @@ static void solve_tri_f32_cuda(const float * A,
                                size_t nb3,
                                cudaStream_t stream) {
     const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
-    dim3 threads(WARP_SIZE, k);
     dim3 grid(ne02 * ne03);
+
+    // Handle large matrices first (256×256 and 65-128 range)
+
+    // Route sizes 65-256 to the tiled kernel
+    if (n > 64 || k > 64) {
+        // Use the tiled kernel which works for any size up to 256
+        // and only requires ~48KB shared memory (within standard limits)
+        dim3 threads_256(WARP_SIZE, 32); // 1024 threads
+        // Shared memory: 64×64 + 64×65 + 64×64 = 16KB + 16.25KB + 16KB = ~48KB
+        const size_t smem_size = (64 * 64 + 64 * 65 + 64 * 64) * sizeof(float);
+
+        // Configure extended shared memory for this kernel
+        static bool smem_configured_tiled = false;
+        if (!smem_configured_tiled) {
+            cudaFuncSetAttribute(solve_tri_f32_256x256_tiled,
+                                 cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
+            smem_configured_tiled = true;
+        }
+
+        solve_tri_f32_256x256_tiled<<<grid, threads_256, smem_size, stream>>>(
+            A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
+        return;
+    }
+
+    // Limit threads_y to 32 to ensure we don't exceed 1024 threads per block (32 * 32 = 1024)
+    const int threads_y = k <= 32 ? k : 32;
+    dim3 threads(WARP_SIZE, threads_y);
+
     if (n == 64) {
         switch (k) {
+            case 64:
+                {
+                    // Use optimized kernel for n=64, k=64 case (common in Qwen3 Next DeltaNet)
+                    // Block config: 32x32 = 1024 threads (32 warps)
+                    dim3 threads_64x64(WARP_SIZE, 32);
+                    solve_tri_f32_64x64_opt
+                        <<<grid, threads_64x64, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3);
+                }
+                break;
+            case 48:
+                // k=48 needs 2 columns per thread (threads_y=32, some threads handle 1, some 2)
+                solve_tri_f32_fast<64, 48, 32>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
+            case 40:
+                // k=40 needs 2 columns per thread (threads_y=32, some threads handle 1, some 2)
+                solve_tri_f32_fast<64, 40, 32>
+                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+                break;
             case 32:
-                solve_tri_f32_fast<64, 32>
+                solve_tri_f32_fast<64, 32, 32>
                     <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
                 break;
             case 16:
-                solve_tri_f32_fast<64, 16>
+                solve_tri_f32_fast<64, 16, 16>
                     <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
                 break;
             case 14:
-                solve_tri_f32_fast<64, 14>
+                solve_tri_f32_fast<64, 14, 14>
                     <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
                 break;
             case 12:
-                solve_tri_f32_fast<64, 12>
+                solve_tri_f32_fast<64, 12, 12>
                     <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
                 break;
             case 10:
-                solve_tri_f32_fast<64, 10>
+                solve_tri_f32_fast<64, 10, 10>
                     <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
                 break;
            case 8:
-                solve_tri_f32_fast<64, 8>
+                solve_tri_f32_fast<64, 8, 8>
                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
                break;
            case 6:
-                solve_tri_f32_fast<64, 6>
+                solve_tri_f32_fast<64, 6, 6>
                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
                break;
            case 4:
-                solve_tri_f32_fast<64, 4>
+                solve_tri_f32_fast<64, 4, 4>
                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
                break;
            case 2:
-                solve_tri_f32_fast<64, 2>
+                solve_tri_f32_fast<64, 2, 2>
                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
                break;
            case 1:
-                solve_tri_f32_fast<64, 1>
+                solve_tri_f32_fast<64, 1, 1>
                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
                break;
            default:
-                solve_tri_f32_fast<0, 0>
+                solve_tri_f32_fast<0, 0, 0>
                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
         }
     } else { // run general case
-        solve_tri_f32_fast<0, 0>
+        solve_tri_f32_fast<0, 0, 0>
             <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
     }
 }
 
 void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0]; // A (n×n, lower triangular)
-    const ggml_tensor * src1 = dst->src[1]; // B (n×k)
+    const ggml_tensor * src0 = dst->src[0]; // A (triangular n x n matrix)
+    const ggml_tensor * src1 = dst->src[1]; // B (right hand side of n x k equation columns)
 
-    ggml_is_contiguous(src0);
-    ggml_is_contiguous(src1);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
 
     const int64_t n = src0->ne[0];
     const int64_t k = src1->ne[0];
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[3];
 
-    if (n <= MAX_N_FAST && k <= MAX_K_FAST) {
-        solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k,
-                           src0->ne[2], src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
-                           src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
-                           dst->nb[3] / sizeof(float), ctx.stream());
-    } else {
-        solve_tri_f32_cublas(ctx, (const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k,
-                             ne02, ne03, src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
-                             src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
-                             dst->nb[3] / sizeof(float), ctx.stream());
-    }
+    const int64_t total_batches = src0->ne[2] * src0->ne[3];
+    const size_t X_size = n * k * total_batches * sizeof(float);
+
+    // Copy B to X (cuBLAS solves in-place)
+    CUDA_CHECK(cudaMemcpyAsync(
+        dst->data, src1->data, X_size,
+        cudaMemcpyDeviceToDevice, ctx.stream()
+    ));
+
+    solve_tri_f32_cublas(
+        ctx,
+        (const float *) src0->data,
+        (float *) dst->data,
+        n, k,
+        src0->ne[2], src0->ne[3],
+        src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
+        dst->nb[2] / sizeof(float), dst->nb[3] / sizeof(float),
+        ctx.stream()
+    );
 }
@@ -1028,6 +1028,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GATED_LINEAR_ATTN",
     "RWKV_WKV7",
     "SOLVE_TRI",
+    "DELTA_NET",
 
     "UNARY",
 
@@ -1045,7 +1046,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GLU",
 };
 
-static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
+static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1137,6 +1138,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "gated_linear_attn(k, v, q, gate, s)",
     "rwkv_wkv7(r, w, k, v, a, b, s)",
    "A X = B, A triangular, solve X",
+    "delta_net(q, k, v, g, beta, state)",
 
     "unary(x)",
 
@@ -1154,7 +1156,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "glu(x)",
 };
 
-static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
+static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -6093,6 +6095,63 @@ struct ggml_tensor * ggml_solve_tri(
     return result;
 }

+// delta_net
+
+struct ggml_tensor * ggml_delta_net(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * g,
+        struct ggml_tensor  * beta,
+        struct ggml_tensor  * state) {
+    GGML_ASSERT(ggml_is_contiguous(q));
+    GGML_ASSERT(ggml_is_contiguous(k));
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(g));
+    GGML_ASSERT(ggml_is_contiguous(beta));
+    GGML_ASSERT(ggml_is_contiguous(state));
+
+    GGML_ASSERT(q->type == GGML_TYPE_F32);
+    GGML_ASSERT(k->type == GGML_TYPE_F32);
+    GGML_ASSERT(v->type == GGML_TYPE_F32);
+    GGML_ASSERT(g->type == GGML_TYPE_F32);
+    GGML_ASSERT(beta->type == GGML_TYPE_F32);
+    GGML_ASSERT(state->type == GGML_TYPE_F32);
+
+    const int64_t S_k      = q->ne[0];
+    const int64_t n_tokens = q->ne[1];
+    const int64_t H_k      = q->ne[2];
+    const int64_t n_seqs   = q->ne[3];
+
+    const int64_t S_v = v->ne[0];
+    const int64_t H_v = v->ne[2];
+
+    GGML_UNUSED(S_k);
+
+    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == n_tokens && k->ne[2] == H_k && k->ne[3] == n_seqs);
+    GGML_ASSERT(v->ne[1] == n_tokens && v->ne[3] == n_seqs);
+    GGML_ASSERT(g->ne[0] == n_tokens && g->ne[2] == H_k && g->ne[3] == n_seqs);
+    GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[3] == n_seqs);
+    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[3] == n_seqs);
+    GGML_ASSERT(H_k == H_v);
+
+    const int64_t output_size = S_v * H_v * n_tokens * n_seqs;
+    const int64_t state_size  = S_v * S_v * H_v * n_seqs;
+
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, output_size + state_size);
+
+    result->op     = GGML_OP_DELTA_NET;
+    result->src[0] = q;
+    result->src[1] = k;
+    result->src[2] = v;
+    result->src[3] = g;
+    result->src[4] = beta;
+    result->src[5] = state;
+
+    return result;
+}
+
 ////////////////////////////////////////////////////////////////////////////////

 struct ggml_hash_set ggml_hash_set_new(size_t size) {

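As an aside, the op packs its result as a flat F32 tensor of output_size + state_size elements: the per-token output first, the updated recurrent state right after it. A minimal sketch of how a caller can split that back apart, assuming the same dimensions asserted above and using illustrative names (build_delta_net_fused further down does exactly this):

#include "ggml.h"

// Illustrative helper (not part of the patch): split the packed result of ggml_delta_net
// into the per-token output and the updated recurrent state.
static void delta_net_split(struct ggml_context * ctx, struct ggml_tensor * packed,
                            int64_t S_v, int64_t H_v, int64_t n_tokens, int64_t n_seqs,
                            struct ggml_tensor ** out, struct ggml_tensor ** new_state) {
    const int64_t output_size = S_v * H_v * n_tokens * n_seqs; // stored first
    const int64_t state_size  = S_v * S_v * H_v * n_seqs;      // stored right after

    // [S_v, H_v, n_tokens, n_seqs] view over the leading output_size elements
    *out = ggml_view_4d(ctx, packed,
            S_v, H_v, n_tokens, n_seqs,
            S_v * ggml_element_size(packed),
            S_v * H_v * ggml_element_size(packed),
            S_v * H_v * n_tokens * ggml_element_size(packed),
            0);

    // flat view over the trailing state_size elements
    *new_state = ggml_view_1d(ctx, packed, state_size,
            output_size * ggml_element_size(packed));
}
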
@@ -460,6 +460,15 @@ private:
             ggml_tensor * diag_mask,
             int il);

+    ggml_tensor * build_delta_net_fused(
+            ggml_tensor * q,
+            ggml_tensor * k,
+            ggml_tensor * v,
+            ggml_tensor * g,
+            ggml_tensor * beta,
+            ggml_tensor * state,
+            int il);
+
     ggml_tensor * build_delta_net_autoregressive(
             ggml_tensor * q,
             ggml_tensor * k,

@@ -467,7 +476,7 @@ private:
             ggml_tensor * g,
             ggml_tensor * beta,
             ggml_tensor * state,
             int il);

     ggml_tensor * build_norm_gated(
             ggml_tensor * input,

@@ -33,12 +33,9 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
         cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
         cb(cur, "attn_norm", il);

-        // Determine layer type and build appropriate attention mechanism
         if (hparams.is_recurrent(il)) {
-            // Linear attention layer (gated delta net)
             cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, diag_mask, il);
         } else {
-            // Full attention layer
             cur = build_layer_attn(inp->get_attn(), cur, inp_pos, il);
         }

@@ -47,37 +44,28 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
             inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
         }

-        // Residual connection
         cur = ggml_add(ctx0, cur, inpSA);
         cb(cur, "attn_residual", il);

-        // Save the tensor before post-attention norm for residual connection
         ggml_tensor * ffn_residual = cur;

-        // Post-attention norm
         ggml_tensor * attn_post_norm = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il);
         cb(attn_post_norm, "attn_post_norm", il);

-        // FFN layer (MoE or dense) - without residual connection
         cur = build_layer_ffn(attn_post_norm, il);
         cb(cur, "ffn_out", il);

-        // Residual connection for FFN - add to the tensor from before post_attention_layernorm
         cur = ggml_add(ctx0, cur, ffn_residual);
         cb(cur, "post_moe", il);

-        // Input for next layer
         inpL = cur;
     }
     cur = inpL;

-    // Final norm
     cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);

     cb(cur, "result_norm", -1);
     res->t_embd = cur;

-    // LM head
     cur = build_lora_mm(model.output, cur);

     cb(cur, "result_output", -1);

@@ -426,6 +414,78 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_autoregressive(
     return ggml_concat(ctx0, flat_output, flat_state, 0);
 }

+ggml_tensor * llm_build_qwen3next::build_delta_net_fused(
+        ggml_tensor * q,
+        ggml_tensor * k,
+        ggml_tensor * v,
+        ggml_tensor * g,
+        ggml_tensor * beta,
+        ggml_tensor * state,
+        int il) {
+    GGML_ASSERT(ggml_is_contiguous(q));
+    GGML_ASSERT(ggml_is_contiguous(k));
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(g));
+    GGML_ASSERT(ggml_is_contiguous(beta));
+    GGML_ASSERT(ggml_is_contiguous(state));
+
+    const int64_t S_k      = q->ne[0];
+    const int64_t H_k      = q->ne[1];
+    const int64_t n_tokens = q->ne[2];
+    const int64_t n_seqs   = q->ne[3];
+
+    const int64_t S_v = v->ne[0];
+    const int64_t H_v = v->ne[1];
+
+    GGML_ASSERT(v->ne[2] == n_tokens);
+    GGML_ASSERT(k->ne[2] == n_tokens);
+    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
+    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
+    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
+
+    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
+    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
+
+    GGML_ASSERT(H_k == H_v);
+
+    q    = ggml_cont_4d(ctx0, ggml_permute(ctx0, q,    0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
+    k    = ggml_cont_4d(ctx0, ggml_permute(ctx0, k,    0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
+    v    = ggml_cont_4d(ctx0, ggml_permute(ctx0, v,    0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
+    g    = ggml_cont_4d(ctx0, ggml_permute(ctx0, g,    1, 3, 0, 2), n_tokens, 1, H_k, n_seqs);
+    beta = ggml_cont_4d(ctx0, ggml_permute(ctx0, beta, 1, 2, 0, 3), 1, n_tokens, H_k, n_seqs);
+
+    cb(q,    "q_fused",    il);
+    cb(k,    "k_fused",    il);
+    cb(v,    "v_fused",    il);
+    cb(g,    "g_fused",    il);
+    cb(beta, "beta_fused", il);
+
+    ggml_tensor * fused_result = ggml_delta_net(ctx0, q, k, v, g, beta, state);
+    cb(fused_result, "delta_net_fused_raw", il);
+
+    const int64_t output_size = S_v * H_v * n_tokens * n_seqs;
+    const int64_t state_size  = S_v * S_v * H_v * n_seqs;
+
+    ggml_tensor * output_4d = ggml_view_4d(ctx0, fused_result,
+        S_v, H_v, n_tokens, n_seqs,
+        S_v * ggml_element_size(fused_result),
+        S_v * H_v * ggml_element_size(fused_result),
+        S_v * H_v * n_tokens * ggml_element_size(fused_result),
+        0);
+    cb(output_4d, "fused_output_4d", il);
+
+    ggml_tensor * flat_output = ggml_cont_1d(ctx0, output_4d, output_size);
+    cb(flat_output, "fused_flat_output", il);
+
+    ggml_tensor * flat_state = ggml_view_1d(ctx0, fused_result, state_size,
+        output_size * ggml_element_size(fused_result));
+    cb(flat_state, "fused_flat_state", il);
+
+    ggml_tensor * result = ggml_concat(ctx0, flat_output, flat_state, 0);
+
+    return result;
+}
+
 ggml_tensor * llm_build_qwen3next::build_norm_gated(
     ggml_tensor * input,
     ggml_tensor * weights,

@@ -445,16 +505,11 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn(
     const int64_t n_embd_head = hparams.n_embd_head_v;
     GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);

-    // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention
-
-    // Qwen3Next uses a single Q projection that outputs query + gate
     ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur);
     cb(Qcur_full, "Qcur_full", il);

     Qcur_full = ggml_reshape_4d(ctx0, Qcur_full, n_embd_head * 2, n_head, n_tokens, 1);

-    // Split Q projection into query and gate
-    // The split should be along dimension 0 (the feature dimension)
     ggml_tensor * Qcur = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1,
         Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0);
     ggml_tensor * gate =

@@ -463,11 +518,9 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn(
     cb(Qcur, "Qcur", il);
     cb(gate, "gate", il);

-    // Now reshape Qcur to [n_embd_head, n_head, n_tokens] for multi-head attention
     Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
     cb(Qcur, "Qcur_reshaped", il);

-    // Apply Q normalization
     Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
     cb(Qcur, "Qcur_normed", il);

@@ -477,18 +530,15 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn(
     ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
     cb(Vcur, "Vcur", il);

-    // Apply K normalization
     Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
     Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il);
     cb(Kcur, "Kcur_normed", il);

-    // Reshape gate to [n_embd, n_tokens] for the sigmoid gating (flatten the heads)
     gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
     cb(gate, "gate_reshaped", il);

     Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

-    // Apply RoPE
     Qcur = ggml_rope_ext(
         ctx0, Qcur, inp_pos, nullptr,
         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,

@@ -503,7 +553,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn(
     cb(Kcur, "Kcur", il);
     cb(Vcur, "Vcur", il);

-    // Attention computation
     const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;

     cur = build_attn(inp,

@@ -737,13 +786,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     cb(k_conv, "k_conv_predelta", il);
     cb(v_conv, "v_conv_predelta", il);

-    // Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens
-    ggml_tensor * attn_out;
-    if (n_seq_tokens == 1) {
-        attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
-    } else {
-        attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
-    }
+    ggml_tensor * attn_out = build_delta_net_fused(q_conv, k_conv, v_conv, gate, beta, state, il);
     cb(attn_out, "attn_out", il);

     // The tensors were concatenated 1d, so we need to extract them 1d as well

@@ -795,9 +838,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
 }

 ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int il) {
-    // Check if this is an MoE layer
     if (model.layers[il].ffn_gate_inp != nullptr) {
-        // MoE branch
         ggml_tensor * moe_out =
             build_moe_ffn(cur,
                 model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps,

@@ -807,7 +848,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int
                 true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
         cb(moe_out, "ffn_moe_out", il);

-        // Add shared experts if present - following Qwen3Next reference implementation
         if (model.layers[il].ffn_up_shexp != nullptr) {
             ggml_tensor * ffn_shexp =
                 build_ffn(cur,

@@ -818,23 +858,15 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int
                     LLM_FFN_SILU, LLM_FFN_PAR, il);
             cb(ffn_shexp, "ffn_shexp", il);

-            // Apply shared expert gating as in the reference implementation
-            // The shared expert has its own gate that is sigmoided
-            // Note: ffn_gate_inp_shexp is the shared expert gate (outputs 1 value per token)
             ggml_tensor * shared_gate = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur);
             cb(shared_gate, "shared_expert_gate", il);

-            // Apply sigmoid to the gate
             shared_gate = ggml_sigmoid(ctx0, shared_gate);
             cb(shared_gate, "shared_expert_gate_sigmoid", il);

-            // The gate needs to be broadcast to match the dimensions of ffn_shexp
-            // ffn_shexp is [n_embd, n_tokens, 1, 1] and shared_gate is [1, n_tokens, 1, 1]
-            // We need to repeat the gate along the feature dimension
             shared_gate = ggml_repeat(ctx0, shared_gate, ffn_shexp);
             cb(shared_gate, "shared_expert_gate_broadcast", il);

-            // Apply the gate to the shared expert output
             ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate);
             cb(ffn_shexp, "ffn_shexp_gated", il);

@@ -844,7 +876,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int
             cur = moe_out;
         }
     } else {
-        // Dense FFN branch (not currently used I believe)
         cur = build_ffn(cur,
             model.layers[il].ffn_up, NULL, NULL,
             model.layers[il].ffn_gate, NULL, NULL,

@@ -3550,6 +3550,34 @@ struct test_rwkv_wkv7 : public test_case {
     }
 };

+// GGML_OP_DELTA_NET
+struct test_delta_net : public test_case {
+    const ggml_type type;
+
+    const int64_t n_heads;
+    const int64_t head_dim;
+    const int64_t n_tokens;
+    const int64_t n_seqs;
+
+    std::string vars() override {
+        return VARS_TO_STR5(type, n_heads, head_dim, n_tokens, n_seqs);
+    }
+
+    test_delta_net(ggml_type type = GGML_TYPE_F32,
+            int64_t n_heads = 8, int64_t head_dim = 64, int64_t n_tokens = 32, int64_t n_seqs = 2)
+        : type(type), n_heads(n_heads), head_dim(head_dim), n_tokens(n_tokens), n_seqs(n_seqs) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * q     = ggml_new_tensor_4d(ctx, type, head_dim, n_tokens, n_heads, n_seqs);
+        ggml_tensor * k     = ggml_new_tensor_4d(ctx, type, head_dim, n_tokens, n_heads, n_seqs);
+        ggml_tensor * v     = ggml_new_tensor_4d(ctx, type, head_dim, n_tokens, n_heads, n_seqs);
+        ggml_tensor * g     = ggml_new_tensor_4d(ctx, type, n_tokens, 1, n_heads, n_seqs);
+        ggml_tensor * beta  = ggml_new_tensor_4d(ctx, type, 1, n_tokens, n_heads, n_seqs);
+        ggml_tensor * state = ggml_new_tensor_4d(ctx, type, head_dim, head_dim * n_heads, 1, n_seqs);
+        return ggml_delta_net(ctx, q, k, v, g, beta, state);
+    }
+};
+
 // GGML_OP_MUL_MAT
 struct test_mul_mat : public test_case {
     const ggml_type type_a;

@@ -7322,6 +7350,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4));
     test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 128, 4));

+    test_cases.emplace_back(new test_delta_net(GGML_TYPE_F32, 8, 64, 1, 1));
+    test_cases.emplace_back(new test_delta_net(GGML_TYPE_F32, 8, 64, 32, 1));
+    test_cases.emplace_back(new test_delta_net(GGML_TYPE_F32, 8, 64, 32, 2));
+    test_cases.emplace_back(new test_delta_net(GGML_TYPE_F32, 8, 64, 128, 2));
+
 #if 0
     // > 4GB A matrix. Too slow to be enabled by default.
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 900000, 3, 2592, {1, 1}, {1, 1}));