simplify impl + add cuda code
This commit is contained in:
parent
86833eb747
commit
54ea122385
|
|
@ -2913,7 +2913,7 @@ struct ggml_cplan ggml_graph_plan(
|
|||
case GGML_OP_GATED_DELTA_NET:
|
||||
{
|
||||
const int64_t S_v = node->src[0]->ne[0];
|
||||
cur = (S_v * S_v + 4 * S_v) * sizeof(float) * n_tasks;
|
||||
cur = (S_v * S_v + S_v) * sizeof(float) * n_tasks;
|
||||
} break;
|
||||
case GGML_OP_COUNT:
|
||||
{
|
||||
|
|
|
|||
|
|
@ -10322,17 +10322,14 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
|
|||
GGML_ASSERT(ggml_is_contiguous(src_beta));
|
||||
GGML_ASSERT(ggml_is_contiguous(src_state));
|
||||
|
||||
// scratch layout per thread: [s_t(S_v*S_v) | q_local(S_v) | k_local(S_v) | kv_mem(S_v) | delta(S_v)]
|
||||
// scratch layout per thread: [s_t(S_v*S_v) | delta(S_v)]
|
||||
// s_t holds the transposed (row-major) state for contiguous vector ops
|
||||
const int64_t scratch_per_thread = S_v * S_v + 4 * S_v;
|
||||
const int64_t scratch_per_thread = S_v * S_v + S_v;
|
||||
const int ith = params->ith;
|
||||
float * scratch = (float *)params->wdata + ith * scratch_per_thread + CACHE_LINE_SIZE_F32;
|
||||
|
||||
float * s_t = scratch;
|
||||
float * q_local = scratch + S_v * S_v;
|
||||
float * k_local = scratch + S_v * S_v + S_v;
|
||||
float * kv_mem = scratch + S_v * S_v + 2 * S_v;
|
||||
float * delta = scratch + S_v * S_v + 3 * S_v;
|
||||
float * delta = scratch + S_v * S_v;
|
||||
|
||||
// output layout: [attn_scores | new_states]
|
||||
// attn_scores: S_v * H * n_tokens * n_seqs floats
|
||||
|
|
@ -10348,19 +10345,13 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
|
|||
const float * g_base = (const float *)src_g->data;
|
||||
const float * beta_base = (const float *)src_beta->data;
|
||||
|
||||
const float eps = ggml_get_op_params_f32(dst, 0);
|
||||
|
||||
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
||||
const int64_t h_idx = ir % H;
|
||||
const int64_t sequence = ir / H;
|
||||
|
||||
// output state pointer for this (head, seq) — column-major (ggml layout)
|
||||
float * s_out = state_out_base + (sequence * H + h_idx) * S_v * S_v;
|
||||
|
||||
// Copy state into scratch in row-major layout of S (not S^T)
|
||||
// ggml column-major: s_in[j + i*S_v] = S[j][i] (j=dim0, i=dim1)
|
||||
// row-major of S: s_t[j * S_v + i] = S[j][i] (row j is contiguous over i)
|
||||
// This makes kv_mem[j] = dot(s_t[j*S_v:], k) a contiguous dot product
|
||||
// tranpose
|
||||
const float * s_in = state_in_base + (sequence * H + h_idx) * S_v * S_v;
|
||||
for (int64_t j = 0; j < S_v; ++j) {
|
||||
for (int64_t i = 0; i < S_v; ++i) {
|
||||
|
|
@ -10379,56 +10370,33 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
|
|||
const float * k_d = k_base + qkv_offset;
|
||||
const float * v_d = v_base + qkv_offset;
|
||||
|
||||
// g and beta layout: [H, n_tokens, n_seqs]
|
||||
const int64_t gb_offset = sequence * n_tokens * H + t * H + h_idx;
|
||||
const float beta_val_raw = beta_base[gb_offset];
|
||||
const float beta_val = 1.0f / (1.0f + expf(-beta_val_raw)); // sigmoid
|
||||
const float g_val = expf(g_base[gb_offset]);
|
||||
|
||||
memcpy(q_local, q_d, S_v * sizeof(float));
|
||||
memcpy(k_local, k_d, S_v * sizeof(float));
|
||||
|
||||
// l2-norm q and scale by 1/sqrt(S_v)
|
||||
float norm;
|
||||
ggml_vec_norm_f32(S_v, &norm, q_local);
|
||||
ggml_vec_scale_f32(S_v, q_local, 1.0f / fmaxf(norm, eps));
|
||||
ggml_vec_scale_f32(S_v, q_local, 1.0f / sqrtf((float)S_v));
|
||||
|
||||
// l2-norm k
|
||||
ggml_vec_norm_f32(S_v, &norm, k_local);
|
||||
ggml_vec_scale_f32(S_v, k_local, 1.0f / fmaxf(norm, eps));
|
||||
|
||||
// state decay: S *= exp(g)
|
||||
ggml_vec_scale_f32(S_v * S_v, s_t, g_val);
|
||||
|
||||
// kv_mem[j] = sum_i S[j][i] * k[i] = dot(s_t[j*S_v:], k)
|
||||
// row j of s_t is contiguous -> use ggml_vec_dot_f32
|
||||
for (int64_t j = 0; j < S_v; ++j) {
|
||||
ggml_vec_dot_f32(S_v, &kv_mem[j], 0, &s_t[j * S_v], 0, k_local, 0, 1);
|
||||
}
|
||||
|
||||
// delta = (v - kv_mem) * beta
|
||||
for (int64_t j = 0; j < S_v; ++j) {
|
||||
delta[j] = (v_d[j] - kv_mem[j]) * beta_val;
|
||||
float kv_j;
|
||||
ggml_vec_dot_f32(S_v, &kv_j, 0, &s_t[j * S_v], 0, k_d, 0, 1);
|
||||
delta[j] = (v_d[j] - kv_j) * beta_val;
|
||||
}
|
||||
|
||||
// outer product: S[j][i] += k[i] * delta[j]
|
||||
// s_t[j * S_v + i] += k[i] * delta[j]
|
||||
// row j gets k[:] scaled by delta[j] -> contiguous ggml_vec_mad_f32
|
||||
for (int64_t j = 0; j < S_v; ++j) {
|
||||
ggml_vec_mad_f32(S_v, &s_t[j * S_v], k_local, delta[j]);
|
||||
ggml_vec_mad_f32(S_v, &s_t[j * S_v], k_d, delta[j]);
|
||||
}
|
||||
|
||||
// attn_out[j] = sum_i S[j][i] * q[i] = dot(s_t[j*S_v:], q)
|
||||
for (int64_t j = 0; j < S_v; ++j) {
|
||||
ggml_vec_dot_f32(S_v, &attn_data[j], 0, &s_t[j * S_v], 0, q_local, 0, 1);
|
||||
ggml_vec_dot_f32(S_v, &attn_data[j], 0, &s_t[j * S_v], 0, q_d, 0, 1);
|
||||
}
|
||||
|
||||
attn_data += S_v * H; // advance to next token
|
||||
}
|
||||
|
||||
// copy scratch back to output: row-major of S -> column-major (ggml layout)
|
||||
// s_t[j * S_v + i] = S[j][i] -> s_out[j + i * S_v] = S[j][i]
|
||||
// transpose back
|
||||
for (int64_t j = 0; j < S_v; ++j) {
|
||||
for (int64_t i = 0; i < S_v; ++i) {
|
||||
s_out[j + i * S_v] = s_t[j * S_v + i];
|
||||
|
|
|
|||
|
|
@ -0,0 +1,124 @@
|
|||
#include "ggml-cuda/common.cuh"
|
||||
#include "gated_delta_net.cuh"
|
||||
|
||||
template<int S_v>
|
||||
__global__ void gated_delta_net_cuda(
|
||||
const float * q,
|
||||
const float * k,
|
||||
const float * v,
|
||||
const float * g,
|
||||
const float * beta,
|
||||
const float * curr_state,
|
||||
float * dst,
|
||||
int64_t H,
|
||||
int64_t n_tokens,
|
||||
int64_t n_seqs
|
||||
) {
|
||||
const int64_t h_idx = blockIdx.x;
|
||||
const int64_t sequence = blockIdx.y;
|
||||
const int col = threadIdx.x; // each thread owns one column
|
||||
|
||||
const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
|
||||
float * attn_data = dst;
|
||||
float * state = dst + attn_score_elems;
|
||||
|
||||
const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v;
|
||||
state += state_offset;
|
||||
curr_state += state_offset;
|
||||
attn_data += (sequence * n_tokens * H + h_idx) * S_v;
|
||||
|
||||
// Copy input state to output state (working area)
|
||||
#pragma unroll
|
||||
for (int i = 0; i < S_v; i++) {
|
||||
state[i * S_v + col] = curr_state[i * S_v + col];
|
||||
}
|
||||
|
||||
for (int t = 0; t < n_tokens; t++) {
|
||||
const int64_t qkv_offset = sequence * n_tokens * H * S_v + t * H * S_v + h_idx * S_v;
|
||||
const float * q_t = q + qkv_offset;
|
||||
const float * k_t = k + qkv_offset;
|
||||
const float * v_t = v + qkv_offset;
|
||||
|
||||
const int64_t gb_offset = sequence * n_tokens * H + t * H + h_idx;
|
||||
const float beta_val = 1.0f / (1.0f + expf(-beta[gb_offset]));
|
||||
const float g_val = expf(g[gb_offset]);
|
||||
|
||||
// kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i]
|
||||
float kv_col = 0.0f;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < S_v; i++) {
|
||||
kv_col += state[i * S_v + col] * k_t[i];
|
||||
}
|
||||
|
||||
// delta[col] = (v[col] - g * kv[col]) * beta
|
||||
float delta_col = (v_t[col] - g_val * kv_col) * beta_val;
|
||||
|
||||
// fused: S[i][col] = g * S[i][col] + k[i] * delta[col]
|
||||
#pragma unroll
|
||||
for (int i = 0; i < S_v; i++) {
|
||||
state[i * S_v + col] = g_val * state[i * S_v + col] + k_t[i] * delta_col;
|
||||
}
|
||||
|
||||
// attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
|
||||
float attn_col = 0.0f;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < S_v; i++) {
|
||||
attn_col += state[i * S_v + col] * q_t[i];
|
||||
}
|
||||
attn_data[col] = attn_col;
|
||||
attn_data += S_v * H;
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
ggml_tensor * src_q = dst->src[0];
|
||||
ggml_tensor * src_k = dst->src[1];
|
||||
ggml_tensor * src_v = dst->src[2];
|
||||
ggml_tensor * src_g = dst->src[3];
|
||||
ggml_tensor * src_beta = dst->src[4];
|
||||
ggml_tensor * src_state = dst->src[5];
|
||||
|
||||
const int64_t S_v = src_q->ne[0];
|
||||
const int64_t H = src_q->ne[1];
|
||||
const int64_t n_tokens = src_q->ne[2];
|
||||
const int64_t n_seqs = src_q->ne[3];
|
||||
|
||||
const float * q_d = (const float *) src_q->data;
|
||||
const float * k_d = (const float *) src_k->data;
|
||||
const float * v_d = (const float *) src_v->data;
|
||||
const float * g_d = (const float *) src_g->data;
|
||||
|
||||
const float * b_d = (const float *) src_beta->data;
|
||||
const float * s_d = (const float *) src_state->data;
|
||||
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(src_q));
|
||||
GGML_ASSERT(ggml_is_contiguous(src_k));
|
||||
GGML_ASSERT(ggml_is_contiguous(src_v));
|
||||
GGML_ASSERT(ggml_is_contiguous(src_g));
|
||||
GGML_ASSERT(ggml_is_contiguous(src_beta));
|
||||
GGML_ASSERT(ggml_is_contiguous(src_state));
|
||||
|
||||
dim3 grid_dims(H, n_seqs, 1);
|
||||
dim3 block_dims(S_v, 1, 1);
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
switch(S_v) {
|
||||
case 32:
|
||||
gated_delta_net_cuda<32><<<grid_dims, block_dims, 0, stream>>>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs);
|
||||
break;
|
||||
case 64:
|
||||
gated_delta_net_cuda<64><<<grid_dims, block_dims, 0, stream>>>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs);
|
||||
break;
|
||||
case 128:
|
||||
gated_delta_net_cuda<128><<<grid_dims, block_dims, 0, stream>>>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
#include "common.cuh"
|
||||
#include "ggml.h"
|
||||
|
||||
void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
|
@ -53,6 +53,7 @@
|
|||
#include "ggml-cuda/upscale.cuh"
|
||||
#include "ggml-cuda/wkv.cuh"
|
||||
#include "ggml-cuda/gla.cuh"
|
||||
#include "ggml-cuda/gated_delta_net.cuh"
|
||||
#include "ggml-cuda/set.cuh"
|
||||
#include "ggml-cuda/set-rows.cuh"
|
||||
#include "ggml-cuda/pad_reflect_1d.cuh"
|
||||
|
|
@ -2730,6 +2731,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
ggml_cuda_op_gated_linear_attn(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_GATED_DELTA_NET:
|
||||
ggml_cuda_op_gated_delta_net(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
ggml_cuda_op_rwkv_wkv7(ctx, dst);
|
||||
break;
|
||||
|
|
@ -4844,6 +4848,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
case GGML_OP_GATED_DELTA_NET:
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
return true;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
|
|
|
|||
|
|
@ -781,31 +781,15 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
|
|||
cb(k_conv, "k_conv_predelta", il);
|
||||
cb(v_conv, "v_conv_predelta", il);
|
||||
|
||||
// Choose between build_delta_net_chunking and fused ggml_gated_delta_net based on n_tokens
|
||||
ggml_tensor * output;
|
||||
ggml_tensor * new_state;
|
||||
// Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens
|
||||
std::pair<ggml_tensor *, ggml_tensor *> attn_out; // pair of (output, new_state)
|
||||
if (n_seq_tokens == 1) {
|
||||
// Fused op expects state as [S_v*S_v*H, n_seqs]
|
||||
ggml_tensor * state_2d = ggml_reshape_2d(ctx0, state, head_v_dim * head_v_dim * num_v_heads, n_seqs);
|
||||
ggml_tensor * result = ggml_gated_delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state_2d,
|
||||
hparams.f_norm_rms_eps);
|
||||
|
||||
// Unpack: attn scores then new state
|
||||
const int64_t attn_elems = head_v_dim * num_v_heads * n_seq_tokens * n_seqs;
|
||||
const int64_t state_elems = head_v_dim * head_v_dim * num_v_heads * n_seqs;
|
||||
|
||||
output = ggml_view_4d(ctx0, result, head_v_dim, num_v_heads, n_seq_tokens, n_seqs,
|
||||
head_v_dim * sizeof(float),
|
||||
head_v_dim * num_v_heads * sizeof(float),
|
||||
head_v_dim * num_v_heads * n_seq_tokens * sizeof(float),
|
||||
0);
|
||||
new_state = ggml_view_1d(ctx0, result, state_elems, attn_elems * sizeof(float));
|
||||
attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
|
||||
} else {
|
||||
std::pair<ggml_tensor *, ggml_tensor *> attn_out;
|
||||
attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
|
||||
output = attn_out.first;
|
||||
new_state = attn_out.second;
|
||||
attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
|
||||
}
|
||||
ggml_tensor * output = attn_out.first;
|
||||
ggml_tensor * new_state = attn_out.second;
|
||||
cb(output, "attn_output", il);
|
||||
cb(new_state, "new_state", il);
|
||||
|
||||
|
|
|
|||
|
|
@ -3635,6 +3635,35 @@ struct test_rwkv_wkv6 : public test_case {
|
|||
}
|
||||
};
|
||||
|
||||
// GGML_OP_GATED_DELTA_NET
|
||||
struct test_gated_delta_net : public test_case {
|
||||
const ggml_type type;
|
||||
|
||||
const int64_t head_count;
|
||||
const int64_t head_size;
|
||||
const int64_t n_seq_tokens;
|
||||
const int64_t n_seqs;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
|
||||
}
|
||||
|
||||
test_gated_delta_net(ggml_type type = GGML_TYPE_F32,
|
||||
int64_t head_count = 4, int64_t head_size = 16, int64_t n_seq_tokens = 1, int64_t n_seqs = 1)
|
||||
: type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * q = ggml_new_tensor_4d(ctx, type, head_size, head_count, n_seq_tokens, n_seqs);
|
||||
ggml_tensor * k = ggml_new_tensor_4d(ctx, type, head_size, head_count, n_seq_tokens, n_seqs);
|
||||
ggml_tensor * v = ggml_new_tensor_4d(ctx, type, head_size, head_count, n_seq_tokens, n_seqs);
|
||||
ggml_tensor * g = ggml_new_tensor_3d(ctx, type, head_count, n_seq_tokens, n_seqs);
|
||||
ggml_tensor * beta = ggml_new_tensor_3d(ctx, type, head_count, n_seq_tokens, n_seqs);
|
||||
ggml_tensor * state = ggml_new_tensor_2d(ctx, type, head_size * head_size * head_count, n_seqs);
|
||||
ggml_tensor * out = ggml_gated_delta_net(ctx, q, k, v, g, beta, state, 1e-6f);
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_GATED_LINEAR_ATTN
|
||||
struct test_gla : public test_case {
|
||||
const ggml_type type;
|
||||
|
|
@ -8310,6 +8339,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
}
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 1, 1));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 16, 64, 1, 2));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 1));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 8, 32, 4, 2));
|
||||
|
||||
#if 0
|
||||
// these tests are disabled to save execution time, sbut they can be handy for debugging
|
||||
test_cases.emplace_back(new test_llama(2, true));
|
||||
|
|
|
|||
Loading…
Reference in New Issue