ggml: add GATED_DELTA_NET op

This commit is contained in:
Aman Gupta 2026-02-11 10:52:49 +05:30
parent 9a96352729
commit 8666c546f9
5 changed files with 258 additions and 2 deletions

View File

@ -556,6 +556,7 @@ extern "C" {
GGML_OP_GATED_LINEAR_ATTN,
GGML_OP_RWKV_WKV7,
GGML_OP_SOLVE_TRI,
GGML_OP_GATED_DELTA_NET,
GGML_OP_UNARY,
@ -2466,6 +2467,15 @@ extern "C" {
bool lower,
bool uni);
GGML_API struct ggml_tensor * ggml_gated_delta_net(
struct ggml_context * ctx,
struct ggml_tensor * q,
struct ggml_tensor * k,
struct ggml_tensor * v,
struct ggml_tensor * g,
struct ggml_tensor * beta,
struct ggml_tensor * state);
// custom operators
typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);

View File

@ -2021,6 +2021,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_solve_tri(params, tensor);
} break;
case GGML_OP_GATED_DELTA_NET:
{
ggml_compute_forward_gated_delta_net(params, tensor);
} break;
case GGML_OP_MAP_CUSTOM1:
{
ggml_compute_forward_map_custom1(params, tensor);
@ -2200,6 +2204,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
} break;
case GGML_OP_COUNT_EQUAL:
case GGML_OP_SOLVE_TRI:
case GGML_OP_GATED_DELTA_NET:
{
n_tasks = n_threads;
} break;
@ -2905,6 +2910,11 @@ struct ggml_cplan ggml_graph_plan(
{
cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
} break;
case GGML_OP_GATED_DELTA_NET:
{
const int64_t S_v = node->src[0]->ne[0];
cur = 4 * S_v * sizeof(float) * n_tasks;
} break;
case GGML_OP_COUNT:
{
GGML_ABORT("fatal error");

View File

@ -10296,6 +10296,186 @@ void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, s
}
}
// ggml_compute_forward_gated_delta_net
static void ggml_compute_forward_gated_delta_net_one_chunk(
const ggml_compute_params * params,
ggml_tensor * dst,
int64_t ir0,
int64_t ir1) {
ggml_tensor * src_q = dst->src[0];
ggml_tensor * src_k = dst->src[1];
ggml_tensor * src_v = dst->src[2];
ggml_tensor * src_g = dst->src[3];
ggml_tensor * src_beta = dst->src[4];
ggml_tensor * src_state = dst->src[5];
const int64_t S_v = src_q->ne[0];
const int64_t H = src_q->ne[1];
const int64_t n_tokens = src_q->ne[2];
const int64_t n_seqs = src_q->ne[3];
GGML_ASSERT(ggml_is_contiguous(src_q));
GGML_ASSERT(ggml_is_contiguous(src_k));
GGML_ASSERT(ggml_is_contiguous(src_v));
GGML_ASSERT(ggml_is_contiguous(src_g));
GGML_ASSERT(ggml_is_contiguous(src_beta));
GGML_ASSERT(ggml_is_contiguous(src_state));
// scratch layout per thread: [q_local(S_v) | k_local(S_v) | kv_mem(S_v) | delta(S_v)]
const int64_t scratch_per_thread = 4 * S_v;
const int ith = params->ith;
float * scratch = (float *)params->wdata + ith * scratch_per_thread + CACHE_LINE_SIZE_F32;
float * q_local = scratch;
float * k_local = scratch + S_v;
float * kv_mem = scratch + 2 * S_v;
float * delta = scratch + 3 * S_v;
// output layout: [attn_scores | new_states]
// attn_scores: S_v * H * n_tokens * n_seqs floats
// new_states: S_v * S_v * H * n_seqs floats
const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
float * attn_out_base = (float *)dst->data;
float * state_out_base = (float *)dst->data + attn_score_elems;
const float * state_in_base = (const float *)src_state->data;
const float * q_base = (const float *)src_q->data;
const float * k_base = (const float *)src_k->data;
const float * v_base = (const float *)src_v->data;
const float * g_base = (const float *)src_g->data;
const float * beta_base = (const float *)src_beta->data;
const float eps = ggml_get_op_params_f32(dst, 0);
for (int64_t ir = ir0; ir < ir1; ++ir) {
const int64_t h_idx = ir % H;
const int64_t sequence = ir / H;
// output state pointer for this (head, seq)
float * s_out = state_out_base + (sequence * H + h_idx) * S_v * S_v;
// copy input state for this (head, seq) into output
const float * s_in = state_in_base + (sequence * H + h_idx) * S_v * S_v;
memcpy(s_out, s_in, S_v * S_v * sizeof(float));
// attn output pointer for first token of this (head, seq)
float * attn_data = attn_out_base + (sequence * n_tokens * H + h_idx) * S_v;
for (int64_t t = 0; t < n_tokens; t++) {
// input pointers for this (head, seq, token)
// layout is contiguous [S_v, H, n_tokens, n_seqs]
const int64_t qkv_offset = sequence * n_tokens * H * S_v + t * H * S_v + h_idx * S_v;
const float * q_d = q_base + qkv_offset;
const float * k_d = k_base + qkv_offset;
const float * v_d = v_base + qkv_offset;
// g and beta layout: [H, n_tokens, n_seqs]
const int64_t gb_offset = sequence * n_tokens * H + t * H + h_idx;
const float beta_val_raw = beta_base[gb_offset];
const float beta_val = 1.0f / (1.0f + expf(-beta_val_raw)); // sigmoid
const float g_val = expf(g_base[gb_offset]);
memcpy(q_local, q_d, S_v * sizeof(float));
memcpy(k_local, k_d, S_v * sizeof(float));
// l2-norm q and scale by 1/sqrt(S_v)
float norm;
ggml_vec_norm_f32(S_v, &norm, q_local);
ggml_vec_scale_f32(S_v, q_local, 1.0f / fmaxf(norm, eps));
ggml_vec_scale_f32(S_v, q_local, 1.0f / sqrtf((float)S_v));
// l2-norm k
ggml_vec_norm_f32(S_v, &norm, k_local);
ggml_vec_scale_f32(S_v, k_local, 1.0f / fmaxf(norm, eps));
// state decay: S *= exp(g)
ggml_vec_scale_f32(S_v * S_v, s_out, g_val);
// kv_mem = S @ k
for (int64_t i = 0; i < S_v; ++i) {
ggml_vec_dot_f32(S_v, &kv_mem[i], 0, &s_out[i * S_v], 0, k_local, 0, 1);
}
// delta = (v - kv_mem) * beta
for (int64_t i = 0; i < S_v; ++i) {
delta[i] = (v_d[i] - kv_mem[i]) * beta_val;
}
// outer product update: S += k (x) delta
for (int64_t i = 0; i < S_v; ++i) {
ggml_vec_mad_f32(S_v, &s_out[i * S_v], delta, k_local[i]);
}
// attn output = S @ q
for (int64_t i = 0; i < S_v; ++i) {
ggml_vec_dot_f32(S_v, &attn_data[i], 0, &s_out[i * S_v], 0, q_local, 0, 1);
}
attn_data += S_v * H; // advance to next token
}
}
}
static void ggml_compute_forward_gated_delta_net_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
ggml_tensor * Q = dst->src[0];
int64_t nr = Q->ne[1] * Q->ne[3];
// disable for NUMA
const bool disable_chunking = ggml_is_numa();
int nth = params->nth;
int ith = params->ith;
// 4x chunks per thread
int nth_scaled = nth * 4;
int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
if (nth == 1 || nchunk < nth || disable_chunking) {
nchunk = nth;
}
if (ith == 0) {
ggml_threadpool_chunk_set(params->threadpool, nth);
}
ggml_barrier(params->threadpool);
const int64_t dr = (nr + nchunk - 1) / nchunk;
int current_chunk = ith;
while (current_chunk < nchunk) {
const int64_t ir0 = dr * current_chunk;
const int64_t ir1 = MIN(ir0 + dr, nr);
ggml_compute_forward_gated_delta_net_one_chunk(params, dst, ir0, ir1);
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
}
}
void ggml_compute_forward_gated_delta_net(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_gated_delta_net_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_rwkv_wkv7
static void ggml_compute_forward_rwkv_wkv7_f32(

View File

@ -102,6 +102,7 @@ void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, s
void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_gated_delta_net(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_map_custom3(const struct ggml_compute_params * params, struct ggml_tensor * dst);

View File

@ -1030,6 +1030,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"GATED_LINEAR_ATTN",
"RWKV_WKV7",
"SOLVE_TRI",
"GATED_DELTA_NET",
"UNARY",
@ -1047,7 +1048,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"GLU",
};
static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@ -1139,6 +1140,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"gated_linear_attn(k, v, q, gate, s)",
"rwkv_wkv7(r, w, k, v, a, b, s)",
"A X = B, A triangular, solve X",
"gated_delta_net(q, k, v, g, beta, s)",
"unary(x)",
@ -1156,7 +1158,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"glu(x)",
};
static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@ -6101,6 +6103,59 @@ struct ggml_tensor * ggml_solve_tri(
return result;
}
// ggml_gated_delta_net
struct ggml_tensor * ggml_gated_delta_net(
struct ggml_context * ctx,
struct ggml_tensor * q,
struct ggml_tensor * k,
struct ggml_tensor * v,
struct ggml_tensor * g,
struct ggml_tensor * beta,
struct ggml_tensor * state) {
GGML_ASSERT(ggml_is_contiguous(q));
GGML_ASSERT(ggml_is_contiguous(k));
GGML_ASSERT(ggml_is_contiguous(v));
GGML_ASSERT(ggml_is_contiguous(g));
GGML_ASSERT(ggml_is_contiguous(beta));
GGML_ASSERT(ggml_is_contiguous(state));
GGML_ASSERT(q->type == GGML_TYPE_F32);
GGML_ASSERT(k->type == GGML_TYPE_F32);
GGML_ASSERT(v->type == GGML_TYPE_F32);
GGML_ASSERT(g->type == GGML_TYPE_F32);
GGML_ASSERT(beta->type == GGML_TYPE_F32);
GGML_ASSERT(state->type == GGML_TYPE_F32);
const int64_t S_k = q->ne[0];
const int64_t H = q->ne[1];
const int64_t n_tokens = q->ne[2];
const int64_t n_seqs = q->ne[3];
const int64_t S_v = v->ne[0];
GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
GGML_ASSERT(v->ne[1] == H && v->ne[2] == n_tokens && v->ne[3] == n_seqs);
GGML_ASSERT(g->ne[0] == H && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
GGML_ASSERT(beta->ne[0] == H && beta->ne[1] == n_tokens && beta->ne[2] == n_seqs);
GGML_ASSERT(ggml_nelements(state) == S_v * S_v * H * n_seqs);
// concat output and new_state into a single tensor
// output: S_v * H * n_tokens, state: S_v * S_v * H * n_seqs
const int64_t ne[4] = { S_v * H, n_tokens + S_v * n_seqs, 1, 1 };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
result->op = GGML_OP_GATED_DELTA_NET;
result->src[0] = q;
result->src[1] = k;
result->src[2] = v;
result->src[3] = g;
result->src[4] = beta;
result->src[5] = state;
return result;
}
////////////////////////////////////////////////////////////////////////////////
struct ggml_hash_set ggml_hash_set_new(size_t size) {