Aman Gupta 2026-02-13 21:58:24 +08:00 committed by GitHub
commit 28b9f9ef22
9 changed files with 477 additions and 2 deletions

View File

@@ -556,6 +556,7 @@ extern "C" {
GGML_OP_GATED_LINEAR_ATTN,
GGML_OP_RWKV_WKV7,
GGML_OP_SOLVE_TRI,
GGML_OP_GATED_DELTA_NET,
GGML_OP_UNARY,
@@ -2466,6 +2467,15 @@ extern "C" {
bool lower,
bool uni);
GGML_API struct ggml_tensor * ggml_gated_delta_net(
struct ggml_context * ctx,
struct ggml_tensor * q,
struct ggml_tensor * k,
struct ggml_tensor * v,
struct ggml_tensor * g,
struct ggml_tensor * beta,
struct ggml_tensor * state);
// custom operators
typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
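A minimal call sketch (illustrative, not part of the commit; ctx and the dimension variables are assumed to be set up by the caller). The shapes follow the CPU implementation and tests below, with S_v the value head size and H the number of value heads; v may carry more heads than q/k, in which case the kernels broadcast them:

struct ggml_tensor * q     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S_v, H, n_tokens, n_seqs);
struct ggml_tensor * k     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S_v, H, n_tokens, n_seqs);
struct ggml_tensor * v     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S_v, H, n_tokens, n_seqs);
struct ggml_tensor * g     = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, H, n_tokens, n_seqs);   // log-space decay gate
struct ggml_tensor * beta  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, H, n_tokens, n_seqs);   // pre-sigmoid mixing strength
struct ggml_tensor * state = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, S_v*S_v*H, n_seqs);     // recurrent state
struct ggml_tensor * out   = ggml_gated_delta_net(ctx, q, k, v, g, beta, state);            // fused scores + new state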

View File

@@ -2021,6 +2021,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
{
ggml_compute_forward_solve_tri(params, tensor);
} break;
case GGML_OP_GATED_DELTA_NET:
{
ggml_compute_forward_gated_delta_net(params, tensor);
} break;
case GGML_OP_MAP_CUSTOM1:
{
ggml_compute_forward_map_custom1(params, tensor);
@@ -2200,6 +2204,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
} break;
case GGML_OP_COUNT_EQUAL:
case GGML_OP_SOLVE_TRI:
case GGML_OP_GATED_DELTA_NET:
{
n_tasks = n_threads;
} break;
@@ -2905,6 +2910,11 @@ struct ggml_cplan ggml_graph_plan(
{
cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
} break;
case GGML_OP_GATED_DELTA_NET:
{
const int64_t S_v = node->src[2]->ne[0];
cur = (S_v * S_v + S_v) * sizeof(float) * n_tasks;
} break;
case GGML_OP_COUNT:
{
GGML_ABORT("fatal error");

View File

@@ -10360,6 +10360,192 @@ void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
}
}
// ggml_compute_forward_gated_delta_net
static void ggml_compute_forward_gated_delta_net_one_chunk(
const ggml_compute_params * params,
ggml_tensor * dst,
int64_t ir0,
int64_t ir1) {
ggml_tensor * src_q = dst->src[0];
ggml_tensor * src_k = dst->src[1];
ggml_tensor * src_v = dst->src[2];
ggml_tensor * src_g = dst->src[3];
ggml_tensor * src_beta = dst->src[4];
ggml_tensor * src_state = dst->src[5];
const int64_t S_v = src_v->ne[0];
const int64_t H = src_v->ne[1];
const int64_t n_tokens = src_v->ne[2];
const int64_t n_seqs = src_v->ne[3];
GGML_ASSERT(ggml_is_contiguous_rows(src_q));
GGML_ASSERT(ggml_is_contiguous_rows(src_k));
GGML_ASSERT(ggml_is_contiguous_rows(src_v));
GGML_ASSERT(ggml_is_contiguous(src_g));
GGML_ASSERT(ggml_is_contiguous(src_beta));
GGML_ASSERT(ggml_is_contiguous(src_state));
// TODO: support KDA
GGML_ASSERT(ggml_are_same_shape(src_beta, src_g));
GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);
GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb);
GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne);
GGML_TENSOR_LOCALS(size_t, nbk, src_k, nb);
GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne);
GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb);
GGML_TENSOR_LOCALS(int64_t, neg, src_g, ne);
GGML_TENSOR_LOCALS(size_t, nbg, src_g, nb);
// scratch layout per thread: [s_t(S_v*S_v) | delta(S_v)]
// s_t holds the transposed (row-major) state for contiguous vector ops
const int64_t scratch_per_thread = S_v * S_v + S_v;
const int ith = params->ith;
float * scratch = (float *)params->wdata + ith * scratch_per_thread + CACHE_LINE_SIZE_F32;
float * s_t = scratch;
float * delta = scratch + S_v * S_v;
// output layout: [attn_scores | new_states]
// attn_scores: S_v * H * n_tokens * n_seqs floats
// new_states: S_v * S_v * H * n_seqs floats
const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
float * attn_out_base = (float *)dst->data;
float * state_out_base = (float *)dst->data + attn_score_elems;
const float * state_in_base = (const float *)src_state->data;
const float * g_base = (const float *)src_g->data;
const float * beta_base = (const float *)src_beta->data;
// broadcast ratios: v may have more heads / sequences than q and k,
// in which case consecutive v heads map onto the same q/k head
const int64_t rq1 = nev1 / neq1;
const int64_t rk1 = nev1 / nek1;
const int64_t rq3 = nev3 / neq3;
const int64_t rk3 = nev3 / nek3;
for (int64_t ir = ir0; ir < ir1; ++ir) {
const int64_t iv1 = ir % H; // head_index
const int64_t iv3 = ir / H; // sequence
const int64_t iq1 = iv1 / rq1;
const int64_t ik1 = iv1 / rk1;
const int64_t iq3 = iv3 / rq3;
const int64_t ik3 = iv3 / rk3;
float * s_out = state_out_base + (iv3 * H + iv1) * S_v * S_v;
// transpose the input state into the scratch buffer
const float * s_in = state_in_base + (iv3 * H + iv1) * S_v * S_v;
for (int64_t j = 0; j < S_v; ++j) {
for (int64_t i = 0; i < S_v; ++i) {
s_t[j * S_v + i] = s_in[j + i * S_v];
}
}
// attn output pointer for first token of this (head, seq)
float * attn_data = attn_out_base + (iv3 * n_tokens * H + iv1) * S_v;
for (int64_t t = 0; t < n_tokens; t++) {
const float * q_d = (const float *)((const char *)src_q->data + iq3 * nbq3 + t * nbq2 + iq1 * nbq1);
const float * k_d = (const float *)((const char *)src_k->data + ik3 * nbk3 + t * nbk2 + ik1 * nbk1);
const float * v_d = (const float *)((const char *)src_v->data + iv3 * nbv3 + t * nbv2 + iv1 * nbv1);
const int64_t gb_offset = iv3 * neg1 * neg0 + t * neg0 + iv1;
const float beta_val_raw = beta_base[gb_offset];
const float beta_val = 1.0f / (1.0f + expf(-beta_val_raw)); // sigmoid
const float g_val = expf(g_base[gb_offset]);
ggml_vec_scale_f32(S_v * S_v, s_t, g_val);
for (int64_t j = 0; j < S_v; ++j) {
float kv_j;
ggml_vec_dot_f32(S_v, &kv_j, 0, &s_t[j * S_v], 0, k_d, 0, 1);
delta[j] = (v_d[j] - kv_j) * beta_val;
}
// outer product: S[j][i] += k[i] * delta[j]
for (int64_t j = 0; j < S_v; ++j) {
ggml_vec_mad_f32(S_v, &s_t[j * S_v], k_d, delta[j]);
}
// attn_out[j] = sum_i S[j][i] * q[i] = dot(s_t[j*S_v:], q)
for (int64_t j = 0; j < S_v; ++j) {
ggml_vec_dot_f32(S_v, &attn_data[j], 0, &s_t[j * S_v], 0, q_d, 0, 1);
}
attn_data += S_v * H; // advance to next token
}
// transpose back
for (int64_t j = 0; j < S_v; ++j) {
for (int64_t i = 0; i < S_v; ++i) {
s_out[j + i * S_v] = s_t[j * S_v + i];
}
}
}
}
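In math terms, the per-token loop above implements the gated delta rule. With \gamma_t = \exp(g_t) (decay gate) and \beta_t = \sigma(b_t) (mixing strength), each head performs

S_t = \gamma_t S_{t-1} + \beta_t \, (v_t - \gamma_t S_{t-1} k_t) \, k_t^\top, \qquad o_t = S_t q_t

i.e. the decayed state is corrected toward storing v_t under key k_t, then queried with q_t.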
static void ggml_compute_forward_gated_delta_net_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
ggml_tensor * V = dst->src[2];
int64_t nr = V->ne[1] * V->ne[3];
// disable for NUMA
const bool disable_chunking = ggml_is_numa();
int nth = params->nth;
int ith = params->ith;
// aim for roughly 4 chunks per thread so the work-stealing loop below can balance load
int nth_scaled = nth * 4;
int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
if (nth == 1 || nchunk < nth || disable_chunking) {
nchunk = nth;
}
if (ith == 0) {
// threads start on chunks 0..nth-1, so the shared counter begins at nth
ggml_threadpool_chunk_set(params->threadpool, nth);
}
ggml_barrier(params->threadpool);
const int64_t dr = (nr + nchunk - 1) / nchunk;
int current_chunk = ith;
while (current_chunk < nchunk) {
const int64_t ir0 = dr * current_chunk;
const int64_t ir1 = MIN(ir0 + dr, nr);
ggml_compute_forward_gated_delta_net_one_chunk(params, dst, ir0, ir1);
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
}
}
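A worked example of the chunking: with H = 16 heads and n_seqs = 2 there are nr = 32 rows; at nth = 8 threads, nth_scaled = 32 gives chunk_size = 1 and nchunk = 32, so each thread statically starts on row ith and then pulls further rows from the shared atomic counter until all 32 are processed.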
void ggml_compute_forward_gated_delta_net(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_gated_delta_net_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_rwkv_wkv7
static void ggml_compute_forward_rwkv_wkv7_f32(

View File

@@ -102,6 +102,7 @@ void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_gated_delta_net(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_map_custom3(const struct ggml_compute_params * params, struct ggml_tensor * dst);

View File

@@ -0,0 +1,160 @@
#include "gated_delta_net.cuh"
#include "ggml-cuda/common.cuh"
template <int S_v>
__global__ void gated_delta_net_cuda(const float * q,
const float * k,
const float * v,
const float * g,
const float * beta,
const float * curr_state,
float * dst,
int64_t H,
int64_t n_tokens,
int64_t n_seqs,
int64_t sq1,
int64_t sq2,
int64_t sq3,
int64_t sv1,
int64_t sv2,
int64_t sv3,
int64_t sg1,
int64_t sg2,
int64_t rq1,
int64_t rq3) {
const int64_t h_idx = blockIdx.x;
const int64_t sequence = blockIdx.y;
const int col = threadIdx.x; // each thread owns one column
// map this block's v head / sequence onto the (possibly fewer) q/k heads and sequences
const int64_t iq1 = h_idx / rq1;
const int64_t iq3 = sequence / rq3;
const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
float * attn_data = dst;
float * state = dst + attn_score_elems;
const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v;
state += state_offset;
curr_state += state_offset;
attn_data += (sequence * n_tokens * H + h_idx) * S_v;
// Copy input state to output state (working area)
#pragma unroll
for (int i = 0; i < S_v; i++) {
state[i * S_v + col] = curr_state[i * S_v + col];
}
for (int t = 0; t < n_tokens; t++) {
const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1;
const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1; // k reuses q's strides and indices (same stride is asserted on the host)
const float * v_t = v + sequence * sv3 + t * sv2 + h_idx * sv1;
const float * g_t = g + sequence * sg2 + t * sg1;
const float * beta_t = beta + sequence * sg2 + t * sg1;
const float beta_val = 1.0f / (1.0f + expf(-beta_t[h_idx]));
const float g_val = expf(g_t[h_idx]);
// kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i]
float kv_col = 0.0f;
#pragma unroll
for (int i = 0; i < S_v; i++) {
kv_col += state[i * S_v + col] * k_t[i];
}
// delta[col] = (v[col] - g * kv[col]) * beta
float delta_col = (v_t[col] - g_val * kv_col) * beta_val;
// fused: S[i][col] = g * S[i][col] + k[i] * delta[col]
#pragma unroll
for (int i = 0; i < S_v; i++) {
state[i * S_v + col] = g_val * state[i * S_v + col] + k_t[i] * delta_col;
}
// attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
float attn_col = 0.0f;
#pragma unroll
for (int i = 0; i < S_v; i++) {
attn_col += state[i * S_v + col] * q_t[i];
}
attn_data[col] = attn_col;
attn_data += S_v * H;
}
}
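Per column c (the one column each thread owns), and writing \gamma_t = \exp(g_t), \beta_t = \sigma(b_t) as in the CPU path, the token update decomposes as

kv_c = \sum_i S_{i,c} k_i, \quad \Delta_c = \beta_t (v_c - \gamma_t kv_c), \quad S_{i,c} \leftarrow \gamma_t S_{i,c} + k_i \Delta_c, \quad o_c = \sum_i S_{i,c} q_i

so a thread only ever reads and writes its own column of S, while k, q, g and beta are read-only broadcasts; no intra-block synchronization is needed across the token loop.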
void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_tensor * src_q = dst->src[0];
ggml_tensor * src_k = dst->src[1];
ggml_tensor * src_v = dst->src[2];
ggml_tensor * src_g = dst->src[3];
ggml_tensor * src_beta = dst->src[4];
ggml_tensor * src_state = dst->src[5];
GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);
GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb);
GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne);
GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb);
GGML_TENSOR_LOCALS(size_t, nbg, src_g, nb);
const int64_t S_v = nev0;
const int64_t H = nev1;
const int64_t n_tokens = nev2;
const int64_t n_seqs = nev3;
const int64_t rq1 = nev1 / neq1;
const int64_t rq3 = nev3 / neq3;
const float * q_d = (const float *) src_q->data;
const float * k_d = (const float *) src_k->data;
const float * v_d = (const float *) src_v->data;
const float * g_d = (const float *) src_g->data;
const float * b_d = (const float *) src_beta->data;
const float * s_d = (const float *) src_state->data;
float * dst_d = (float *) dst->data;
GGML_ASSERT(ggml_is_contiguous_rows(src_q));
GGML_ASSERT(ggml_is_contiguous_rows(src_k));
GGML_ASSERT(ggml_is_contiguous_rows(src_v));
GGML_ASSERT(ggml_are_same_stride(src_q, src_k));
GGML_ASSERT(ggml_are_same_stride(src_g, src_beta));
GGML_ASSERT(ggml_is_contiguous(src_g));
GGML_ASSERT(ggml_is_contiguous(src_beta));
GGML_ASSERT(ggml_is_contiguous(src_state));
// strides in floats
const int64_t sq1 = nbq1 / sizeof(float);
const int64_t sq2 = nbq2 / sizeof(float);
const int64_t sq3 = nbq3 / sizeof(float);
const int64_t sv1 = nbv1 / sizeof(float);
const int64_t sv2 = nbv2 / sizeof(float);
const int64_t sv3 = nbv3 / sizeof(float);
const int64_t sg1 = nbg1 / sizeof(float);
const int64_t sg2 = nbg2 / sizeof(float);
dim3 grid_dims(H, n_seqs, 1); // one block per (head, sequence) pair
dim3 block_dims(S_v, 1, 1); // one thread per state column
cudaStream_t stream = ctx.stream();
switch (S_v) {
case 32:
gated_delta_net_cuda<32><<<grid_dims, block_dims, 0, stream>>>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2,
sv3, sg1, sg2, rq1, rq3);
break;
case 64:
gated_delta_net_cuda<64><<<grid_dims, block_dims, 0, stream>>>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2,
sv3, sg1, sg2, rq1, rq3);
break;
case 128:
gated_delta_net_cuda<128><<<grid_dims, block_dims, 0, stream>>>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2,
sv3, sg1, sg2, rq1, rq3);
break;
default:
GGML_ABORT("fatal error");
break;
}
}
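Note that only S_v in {32, 64, 128} is instantiated; any other head size aborts. Supporting another size would mean one more case, e.g. (a hypothetical extension, assuming 256 threads per block fits the kernel's register budget):

case 256:
gated_delta_net_cuda<256><<<grid_dims, block_dims, 0, stream>>>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2,
sv3, sg1, sg2, rq1, rq3);
break;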

View File

@@ -0,0 +1,4 @@
#include "common.cuh"
#include "ggml.h"
void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -53,6 +53,7 @@
#include "ggml-cuda/upscale.cuh"
#include "ggml-cuda/wkv.cuh"
#include "ggml-cuda/gla.cuh"
#include "ggml-cuda/gated_delta_net.cuh"
#include "ggml-cuda/set.cuh"
#include "ggml-cuda/set-rows.cuh"
#include "ggml-cuda/pad_reflect_1d.cuh"
@@ -2730,6 +2731,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
case GGML_OP_GATED_LINEAR_ATTN:
ggml_cuda_op_gated_linear_attn(ctx, dst);
break;
case GGML_OP_GATED_DELTA_NET:
ggml_cuda_op_gated_delta_net(ctx, dst);
break;
case GGML_OP_RWKV_WKV7:
ggml_cuda_op_rwkv_wkv7(ctx, dst);
break;
@@ -4849,6 +4853,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
case GGML_OP_LEAKY_RELU:
case GGML_OP_RWKV_WKV6:
case GGML_OP_GATED_LINEAR_ATTN:
case GGML_OP_GATED_DELTA_NET:
case GGML_OP_RWKV_WKV7:
return true;
case GGML_OP_FLASH_ATTN_EXT:

View File

@@ -1030,6 +1030,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"GATED_LINEAR_ATTN",
"RWKV_WKV7",
"SOLVE_TRI",
"GATED_DELTA_NET",
"UNARY",
@@ -1047,7 +1048,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"GLU",
};
static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -1139,6 +1140,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"gated_linear_attn(k, v, q, gate, s)",
"rwkv_wkv7(r, w, k, v, a, b, s)",
"A X = B, A triangular, solve X",
"gated_delta_net(q, k, v, g, beta, s)",
"unary(x)",
@@ -1156,7 +1158,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"glu(x)",
};
static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -6101,6 +6103,53 @@ struct ggml_tensor * ggml_solve_tri(
return result;
}
// ggml_gated_delta_net
struct ggml_tensor * ggml_gated_delta_net(
struct ggml_context * ctx,
struct ggml_tensor * q,
struct ggml_tensor * k,
struct ggml_tensor * v,
struct ggml_tensor * g,
struct ggml_tensor * beta,
struct ggml_tensor * state) {
GGML_ASSERT(ggml_is_contiguous_rows(q));
GGML_ASSERT(ggml_is_contiguous_rows(k));
GGML_ASSERT(ggml_is_contiguous_rows(v));
GGML_ASSERT(ggml_is_contiguous(g));
GGML_ASSERT(ggml_is_contiguous(beta));
GGML_ASSERT(ggml_is_contiguous(state));
GGML_ASSERT(q->type == GGML_TYPE_F32);
GGML_ASSERT(k->type == GGML_TYPE_F32);
GGML_ASSERT(v->type == GGML_TYPE_F32);
GGML_ASSERT(g->type == GGML_TYPE_F32);
GGML_ASSERT(beta->type == GGML_TYPE_F32);
GGML_ASSERT(state->type == GGML_TYPE_F32);
const int64_t S_v = v->ne[0];
const int64_t H = v->ne[1];
const int64_t n_tokens = v->ne[2];
const int64_t n_seqs = v->ne[3];
GGML_ASSERT(ggml_nelements(state) == S_v * S_v * H * n_seqs);
// concat output and new_state into a single tensor
// output: S_v * H * n_tokens * n_seqs, state: S_v * S_v * H * n_seqs
const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + S_v * n_seqs, 1, 1 };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
result->op = GGML_OP_GATED_DELTA_NET;
result->src[0] = q;
result->src[1] = k;
result->src[2] = v;
result->src[3] = g;
result->src[4] = beta;
result->src[5] = state;
return result;
}
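Callers then need to split the fused result back apart. A minimal sketch using views (illustrative, not part of the commit; offsets follow the layout documented above):

// attention output: n_tokens*n_seqs rows of S_v*H floats, at the start of the tensor
struct ggml_tensor * attn = ggml_view_2d(ctx, out, S_v*H, n_tokens*n_seqs, out->nb[1], 0);
// updated state: S_v*n_seqs rows of S_v*H floats, stored after the attention scores
struct ggml_tensor * new_state = ggml_view_2d(ctx, out, S_v*H, S_v*n_seqs, out->nb[1],
(size_t)(S_v*H*n_tokens*n_seqs)*ggml_element_size(out));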
////////////////////////////////////////////////////////////////////////////////
struct ggml_hash_set ggml_hash_set_new(size_t size) {

View File

@@ -3656,6 +3656,48 @@ struct test_rwkv_wkv6 : public test_case {
}
};
// GGML_OP_GATED_DELTA_NET
struct test_gated_delta_net : public test_case {
const ggml_type type;
const int64_t head_count;
const int64_t head_size;
const int64_t n_seq_tokens;
const int64_t n_seqs;
const int v_repeat;
const bool permuted;
std::string vars() override {
return VARS_TO_STR7(type, head_count, head_size, n_seq_tokens, n_seqs, v_repeat, permuted);
}
test_gated_delta_net(ggml_type type = GGML_TYPE_F32,
int64_t head_count = 4, int64_t head_size = 16, int64_t n_seq_tokens = 1, int64_t n_seqs = 1, int v_repeat = 1, bool permuted = false)
: type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs), v_repeat(v_repeat), permuted(permuted) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * q;
ggml_tensor * k;
ggml_tensor * v;
if (permuted) {
// create with dims 1 and 2 swapped, then permute back to get non-contiguous layout
q = ggml_permute(ctx, ggml_new_tensor_4d(ctx, type, head_size, n_seq_tokens, head_count, n_seqs), 0, 2, 1, 3);
k = ggml_permute(ctx, ggml_new_tensor_4d(ctx, type, head_size, n_seq_tokens, head_count, n_seqs), 0, 2, 1, 3);
v = ggml_permute(ctx, ggml_new_tensor_4d(ctx, type, head_size, n_seq_tokens, head_count * v_repeat, n_seqs), 0, 2, 1, 3);
} else {
q = ggml_new_tensor_4d(ctx, type, head_size, head_count, n_seq_tokens, n_seqs);
k = ggml_new_tensor_4d(ctx, type, head_size, head_count, n_seq_tokens, n_seqs);
v = ggml_new_tensor_4d(ctx, type, head_size, head_count * v_repeat, n_seq_tokens, n_seqs);
}
ggml_tensor * g = ggml_new_tensor_3d(ctx, type, head_count * v_repeat, n_seq_tokens, n_seqs);
ggml_tensor * beta = ggml_new_tensor_3d(ctx, type, head_count * v_repeat, n_seq_tokens, n_seqs);
ggml_tensor * state = ggml_new_tensor_2d(ctx, type, head_size * v_repeat * head_size * head_count, n_seqs);
ggml_tensor * out = ggml_gated_delta_net(ctx, q, k, v, g, beta, state);
return out;
}
};
// GGML_OP_GATED_LINEAR_ATTN
struct test_gla : public test_case {
const ggml_type type;
@@ -8373,6 +8415,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
}
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 1, 1));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 16, 64, 1, 2));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 1));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 8, 32, 4, 2, 2));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2, 1, true));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 1, 1, true));
#if 0
// these tests are disabled to save execution time, but they can be handy for debugging
test_cases.emplace_back(new test_llama(2, true));