CANN: add GATED_DELTA_NET op support

hipudding 2026-03-27 08:52:24 +00:00
parent c0e78773e9
commit 140c5a3d1b
3 changed files with 346 additions and 0 deletions


@ -62,6 +62,7 @@
#include <aclnnop/aclnn_permute.h>
#include <aclnnop/aclnn_pow.h>
#include <aclnnop/aclnn_pow_tensor_tensor.h>
#include <aclnnop/aclnn_recurrent_gated_delta_rule.h>
#include <aclnnop/aclnn_reduce_sum.h>
#include <aclnnop/aclnn_reflection_pad1d.h>
#include <aclnnop/aclnn_repeat.h>
@ -4187,3 +4188,323 @@ void ggml_cann_gated_linear_attn(ggml_backend_cann_context & ctx, ggml_tensor *
        }
    }
}

static void ggml_cann_gated_delta_net_naive(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
    ggml_tensor * src_q = dst->src[0];
    ggml_tensor * src_k = dst->src[1];
    ggml_tensor * src_v = dst->src[2];
    ggml_tensor * src_g = dst->src[3];
    ggml_tensor * src_beta = dst->src[4];
    ggml_tensor * src_state = dst->src[5];

    const int64_t S_v = src_v->ne[0];
    const int64_t H = src_v->ne[1];
    const int64_t n_tokens = src_v->ne[2];
    const int64_t n_seqs = src_v->ne[3];
    const int64_t neq1 = src_q->ne[1];
    const int64_t nek1 = src_k->ne[1];
    const int64_t neq3 = src_q->ne[3];
    const int64_t nek3 = src_k->ne[3];

    const bool kda = (src_g->ne[0] == S_v);
    const float scale = 1.0f / sqrtf((float) S_v);
    const size_t nb_f32 = sizeof(float);

    // Q/K strides (may differ from V for GQA)
    const size_t nbq1 = src_q->nb[1], nbq2 = src_q->nb[2], nbq3 = src_q->nb[3];
    const size_t nbk1 = src_k->nb[1], nbk2 = src_k->nb[2], nbk3 = src_k->nb[3];
    const size_t nbv1 = src_v->nb[1], nbv2 = src_v->nb[2], nbv3 = src_v->nb[3];
    const size_t nbg1 = src_g->nb[1], nbg2 = src_g->nb[2], nbg3 = src_g->nb[3];
    const size_t nbb1 = src_beta->nb[1], nbb2 = src_beta->nb[2], nbb3 = src_beta->nb[3];

    const int64_t rq3 = (neq3 > 0) ? n_seqs / neq3 : 1;
    const int64_t rk3 = (nek3 > 0) ? n_seqs / nek3 : 1;

    // Output layout: [attn_scores | new_states]
    const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
    const size_t state_out_offset = attn_score_elems * nb_f32;

    // Shapes for per-head operations
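    // (GGML convention: ne[] holds element counts and nb[] holds byte strides,
    // both ordered innermost-first, so index 0 is the fastest-varying dim.)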
    int64_t ne_s[2] = { S_v, S_v };
    size_t nb_s[2] = { nb_f32, S_v * nb_f32 };
    int64_t ne_vec[1] = { S_v };
    size_t nb_vec[1] = { nb_f32 };
    int64_t ne_sc[1] = { 1 };
    size_t nb_sc[1] = { nb_f32 };
    int64_t ne_g_bc[2] = { S_v, 1 }; // for KDA gate broadcast
    size_t nb_g_bc[2] = { nb_f32, S_v * nb_f32 };

    // Copy input state to output state area
    {
        int64_t ne_flat[1] = { S_v * S_v * H * n_seqs };
        size_t nb_flat[1] = { nb_f32 };
        acl_tensor_ptr acl_sin = ggml_cann_create_tensor(
            src_state->data, ACL_FLOAT, nb_f32, ne_flat, nb_flat, 1);
        acl_tensor_ptr acl_sout = ggml_cann_create_tensor(
            dst->data, ACL_FLOAT, nb_f32, ne_flat, nb_flat, 1, ACL_FORMAT_ND, state_out_offset);
        cann_copy(ctx, acl_sin.get(), acl_sout.get());
    }

    for (int64_t s = 0; s < n_seqs; s++) {
        for (int64_t h = 0; h < H; h++) {
            const size_t s_off = state_out_offset + ((s * H + h) * S_v * S_v) * nb_f32;
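            // Map the V head/sequence onto Q/K indices: the modulo repeats
            // shared Q/K heads across V heads (GQA), and the division maps
            // each sequence to its Q/K batch slot when Q/K are broadcast
            // along the batch dimension.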
            const int64_t iq1 = h % neq1;
            const int64_t ik1 = h % nek1;
            const int64_t iq3 = s / rq3;
            const int64_t ik3 = s / rk3;

            for (int64_t t = 0; t < n_tokens; t++) {
                // State matrix view [S_v, S_v] (transposed storage: M[j][i] = S[i][j])
                // Mv(M, k) = S^T @ k, Mv(M, q) = S^T @ q
                acl_tensor_ptr acl_s_mat = ggml_cann_create_tensor(
                    dst->data, ACL_FLOAT, nb_f32, ne_s, nb_s, 2, ACL_FORMAT_ND, s_off);

                // Input tensor views
                const size_t q_off = iq3 * nbq3 + t * nbq2 + iq1 * nbq1;
                const size_t k_off = ik3 * nbk3 + t * nbk2 + ik1 * nbk1;
                const size_t v_off = s * nbv3 + t * nbv2 + h * nbv1;
                const size_t beta_off = s * nbb3 + t * nbb2 + h * nbb1;
                const size_t g_off = s * nbg3 + t * nbg2 + h * nbg1;

                acl_tensor_ptr acl_q = ggml_cann_create_tensor(
                    src_q->data, ACL_FLOAT, nb_f32, ne_vec, nb_vec, 1, ACL_FORMAT_ND, q_off);
                acl_tensor_ptr acl_k = ggml_cann_create_tensor(
                    src_k->data, ACL_FLOAT, nb_f32, ne_vec, nb_vec, 1, ACL_FORMAT_ND, k_off);
                acl_tensor_ptr acl_v = ggml_cann_create_tensor(
                    src_v->data, ACL_FLOAT, nb_f32, ne_vec, nb_vec, 1, ACL_FORMAT_ND, v_off);
                acl_tensor_ptr acl_beta = ggml_cann_create_tensor(
                    src_beta->data, ACL_FLOAT, nb_f32, ne_sc, nb_sc, 1, ACL_FORMAT_ND, beta_off);

                // Step 1: State decay S *= exp(g)
                if (kda) {
                    // KDA mode: M[j][i] *= exp(g[i]) for all j
                    ggml_cann_pool_alloc g_exp_alloc(ctx.pool(), S_v * nb_f32);
                    acl_tensor_ptr acl_g_src = ggml_cann_create_tensor(
                        src_g->data, ACL_FLOAT, nb_f32, ne_vec, nb_vec, 1, ACL_FORMAT_ND, g_off);
                    acl_tensor_ptr acl_g_exp = ggml_cann_create_tensor(
                        g_exp_alloc.get(), ACL_FLOAT, nb_f32, ne_vec, nb_vec, 1);
                    cann_copy(ctx, acl_g_src.get(), acl_g_exp.get());
                    aclnn_exp(ctx, acl_g_exp.get());
                    // Broadcast: GGML [S_v,1] -> CANN [1,S_v] broadcasts along rows
                    acl_tensor_ptr acl_g_exp_bc = ggml_cann_create_tensor(
                        g_exp_alloc.get(), ACL_FLOAT, nb_f32, ne_g_bc, nb_g_bc, 2);
                    aclnn_mul(ctx, acl_s_mat.get(), acl_g_exp_bc.get(), nullptr);
                } else {
                    // Scalar mode: M *= exp(g[0])
                    ggml_cann_pool_alloc g_exp_alloc(ctx.pool(), nb_f32);
                    acl_tensor_ptr acl_g_src = ggml_cann_create_tensor(
                        src_g->data, ACL_FLOAT, nb_f32, ne_sc, nb_sc, 1, ACL_FORMAT_ND, g_off);
                    acl_tensor_ptr acl_g_exp = ggml_cann_create_tensor(
                        g_exp_alloc.get(), ACL_FLOAT, nb_f32, ne_sc, nb_sc, 1);
                    cann_copy(ctx, acl_g_src.get(), acl_g_exp.get());
                    aclnn_exp(ctx, acl_g_exp.get());
                    aclnn_mul(ctx, acl_s_mat.get(), acl_g_exp.get(), nullptr);
                }

                // Step 2: delta = (v - M @ k) * beta
                ggml_cann_pool_alloc kv_alloc(ctx.pool(), S_v * nb_f32);
                acl_tensor_ptr acl_kv = ggml_cann_create_tensor(
                    kv_alloc.get(), ACL_FLOAT, nb_f32, ne_vec, nb_vec, 1);
                GGML_CANN_CALL_ACLNN_OP(ctx, Mv, acl_s_mat.get(), acl_k.get(), acl_kv.get(), 1);

                ggml_cann_pool_alloc delta_alloc(ctx.pool(), S_v * nb_f32);
                acl_tensor_ptr acl_delta = ggml_cann_create_tensor(
                    delta_alloc.get(), ACL_FLOAT, nb_f32, ne_vec, nb_vec, 1);
                aclnn_sub(ctx, acl_v.get(), acl_kv.get(), acl_delta.get());
                aclnn_mul(ctx, acl_delta.get(), acl_beta.get(), nullptr);

                // Step 3: State update M += delta ⊗ k (outer product)
                ggml_cann_pool_alloc outer_alloc(ctx.pool(), S_v * S_v * nb_f32);
                acl_tensor_ptr acl_outer = ggml_cann_create_tensor(
                    outer_alloc.get(), ACL_FLOAT, nb_f32, ne_s, nb_s, 2);
                GGML_CANN_CALL_ACLNN_OP(ctx, Ger, acl_delta.get(), acl_k.get(), acl_outer.get());
                aclnn_add(ctx, acl_s_mat.get(), acl_outer.get(), nullptr);

                // Step 4: Output attn = M @ q * scale
                const size_t attn_off = ((s * n_tokens * H + t * H + h) * S_v) * nb_f32;
                acl_tensor_ptr acl_attn = ggml_cann_create_tensor(
                    dst->data, ACL_FLOAT, nb_f32, ne_vec, nb_vec, 1, ACL_FORMAT_ND, attn_off);
                GGML_CANN_CALL_ACLNN_OP(ctx, Mv, acl_s_mat.get(), acl_q.get(), acl_attn.get(), 1);
                aclnn_muls(ctx, acl_attn.get(), scale, nullptr, true);
            }
        }
    }
}
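
For reference, the recurrence the ACL calls above implement can be written as a plain host-side loop. The following is a minimal sketch of one (sequence, head) token step with scalar gating, assuming contiguous row-major float buffers; the function name and signature are illustrative only, not part of the backend:

#include <cmath>
#include <cstdint>

// One token step of the gated delta rule on the transposed state
// M (M[j][i] = S[i][j]), stored row-major as S_v x S_v floats.
static void gated_delta_step_ref(float * M, const float * q, const float * k,
                                 const float * v, float g, float beta,
                                 float * out, int64_t S_v) {
    const float decay = expf(g);
    const float scale = 1.0f / sqrtf((float) S_v);
    for (int64_t j = 0; j < S_v; j++) {
        float mk = 0.0f;
        for (int64_t i = 0; i < S_v; i++) {
            M[j * S_v + i] *= decay;             // Step 1: M *= exp(g)
            mk += M[j * S_v + i] * k[i];         // Step 2a: (M @ k)[j]
        }
        const float delta = (v[j] - mk) * beta;  // Step 2b: delta[j]
        for (int64_t i = 0; i < S_v; i++) {
            M[j * S_v + i] += delta * k[i];      // Step 3: M += delta ⊗ k
        }
    }
    for (int64_t j = 0; j < S_v; j++) {          // Step 4: out = (M @ q) * scale
        float mq = 0.0f;
        for (int64_t i = 0; i < S_v; i++) {
            mq += M[j * S_v + i] * q[i];
        }
        out[j] = mq * scale;
    }
}

Per-row interleaving of Steps 1-3 is equivalent to the tensor-level version above because delta[j] depends only on row j of the decayed state.
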
void ggml_cann_gated_delta_net(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
    ggml_tensor * src_q = dst->src[0];
    ggml_tensor * src_k = dst->src[1];
    ggml_tensor * src_v = dst->src[2];
    ggml_tensor * src_g = dst->src[3];
    ggml_tensor * src_beta = dst->src[4];
    ggml_tensor * src_state = dst->src[5];

    const int64_t S_v = src_v->ne[0];
    const int64_t H = src_v->ne[1];
    const int64_t n_tokens = src_v->ne[2];
    const int64_t n_seqs = src_v->ne[3];
    const int64_t neq1 = src_q->ne[1];
    const int64_t nek1 = src_k->ne[1];
    const int64_t neq3 = src_q->ne[3];
    const int64_t nek3 = src_k->ne[3];
    const bool kda = (src_g->ne[0] == S_v);

    // Check whether the fused aclnnRecurrentGatedDeltaRule operator can be used.
    // Constraints from the operator spec:
    // - gk (KDA mode) is not supported in the current CANN version
    // - Li <= 8 (per-sequence token count)
    // - Nk, Nv, Dk, Dv <= 256
    // - Q/K/V head counts must match (Nk == Nv, i.e. no GQA)
    // - no batch-dimension broadcasting for Q/K
    // - input tensors must be contiguous (the fused path assumes a contiguous layout)
    const bool use_fused = !kda
        && n_tokens <= 8
        && S_v <= 256
        && H <= 256
        && neq1 <= 256
        && neq1 == nek1
        && neq1 == H
        && neq3 == n_seqs
        && nek3 == n_seqs
        && ggml_is_contiguous(src_q)
        && ggml_is_contiguous(src_k)
        && ggml_is_contiguous(src_v);

    if (!use_fused) {
        ggml_cann_gated_delta_net_naive(ctx, dst);
        return;
    }

    const int64_t T = n_seqs * n_tokens;
    const float scale = 1.0f / sqrtf((float) S_v);
    const size_t nb_f32 = sizeof(float);
    const size_t nb_bf16 = sizeof(uint16_t);
    const size_t nb_i32 = sizeof(int32_t);

    // Output layout: [attn_scores | new_states]
    const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
    const size_t state_out_offset = attn_score_elems * nb_f32;

    // ---- Cast F32 inputs to BF16 ----
    // Q: GGML [S_v, neq1, n_tokens, n_seqs] → 3D [S_v, neq1, T] → CANN (T, Nk, Dk)
    int64_t ne_q[3] = { S_v, neq1, T };
    size_t nb_q_f32[3] = { nb_f32, S_v * nb_f32, S_v * neq1 * nb_f32 };
    size_t nb_q_bf16[3] = { nb_bf16, S_v * nb_bf16, S_v * neq1 * nb_bf16 };
    ggml_cann_pool_alloc q_bf16_alloc(ctx.pool(), T * neq1 * S_v * nb_bf16);
    acl_tensor_ptr acl_q_f32 = ggml_cann_create_tensor(src_q->data, ACL_FLOAT, nb_f32, ne_q, nb_q_f32, 3);
    acl_tensor_ptr acl_q_bf16 = ggml_cann_create_tensor(q_bf16_alloc.get(), ACL_BF16, nb_bf16, ne_q, nb_q_bf16, 3);
    aclnn_cast(ctx, acl_q_f32.get(), acl_q_bf16.get(), ACL_BF16);

    // K: GGML [S_v, nek1, n_tokens, n_seqs] → 3D [S_v, nek1, T] → CANN (T, Nk, Dk)
    int64_t ne_k[3] = { S_v, nek1, T };
    size_t nb_k_f32[3] = { nb_f32, S_v * nb_f32, S_v * nek1 * nb_f32 };
    size_t nb_k_bf16[3] = { nb_bf16, S_v * nb_bf16, S_v * nek1 * nb_bf16 };
    ggml_cann_pool_alloc k_bf16_alloc(ctx.pool(), T * nek1 * S_v * nb_bf16);
    acl_tensor_ptr acl_k_f32 = ggml_cann_create_tensor(src_k->data, ACL_FLOAT, nb_f32, ne_k, nb_k_f32, 3);
    acl_tensor_ptr acl_k_bf16 = ggml_cann_create_tensor(k_bf16_alloc.get(), ACL_BF16, nb_bf16, ne_k, nb_k_bf16, 3);
    aclnn_cast(ctx, acl_k_f32.get(), acl_k_bf16.get(), ACL_BF16);

    // V: GGML [S_v, H, n_tokens, n_seqs] → 3D [S_v, H, T] → CANN (T, Nv, Dv)
    int64_t ne_v[3] = { S_v, H, T };
    size_t nb_v_f32[3] = { nb_f32, S_v * nb_f32, S_v * H * nb_f32 };
    size_t nb_v_bf16[3] = { nb_bf16, S_v * nb_bf16, S_v * H * nb_bf16 };
    ggml_cann_pool_alloc v_bf16_alloc(ctx.pool(), T * H * S_v * nb_bf16);
    acl_tensor_ptr acl_v_f32 = ggml_cann_create_tensor(src_v->data, ACL_FLOAT, nb_f32, ne_v, nb_v_f32, 3);
    acl_tensor_ptr acl_v_bf16 = ggml_cann_create_tensor(v_bf16_alloc.get(), ACL_BF16, nb_bf16, ne_v, nb_v_bf16, 3);
    aclnn_cast(ctx, acl_v_f32.get(), acl_v_bf16.get(), ACL_BF16);

    // Beta: GGML [1, H, n_tokens, n_seqs] → 2D [H, T] → CANN (T, Nv)
    int64_t ne_hf[2] = { H, T };
    size_t nb_hf_f32[2] = { nb_f32, H * nb_f32 };
    size_t nb_hf_bf16[2] = { nb_bf16, H * nb_bf16 };
    ggml_cann_pool_alloc beta_bf16_alloc(ctx.pool(), T * H * nb_bf16);
    acl_tensor_ptr acl_beta_f32 = ggml_cann_create_tensor(src_beta->data, ACL_FLOAT, nb_f32, ne_hf, nb_hf_f32, 2);
    acl_tensor_ptr acl_beta_bf16 = ggml_cann_create_tensor(beta_bf16_alloc.get(), ACL_BF16, nb_bf16, ne_hf, nb_hf_bf16, 2);
    aclnn_cast(ctx, acl_beta_f32.get(), acl_beta_bf16.get(), ACL_BF16);

    // Gate g: GGML [1, H, n_tokens, n_seqs] → 2D [H, T] → CANN (T, Nv), stays F32
    acl_tensor_ptr acl_g = ggml_cann_create_tensor(src_g->data, ACL_FLOAT, nb_f32, ne_hf, nb_hf_f32, 2);

    // State: GGML [S_v, S_v, H, n_seqs] → CANN (BlockNum, Nv, Dv, Dk)
    // GGML stores M = S^T, but the recurrence applied to M has the same form as the
    // standard delta rule, so M can be passed directly as the API's state parameter.
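    // Concretely: transposing S_t = exp(g_t) * S_{t-1} + beta_t * k_t * (v_t - S_{t-1}^T k_t)^T
    // gives M_t = exp(g_t) * M_{t-1} + beta_t * (v_t - M_{t-1} k_t) * k_t^T,
    // which is exactly the delta-rule form the fused operator computes.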
    const int64_t state_elems = n_seqs * H * S_v * S_v;
    int64_t ne_st[4] = { S_v, S_v, H, n_seqs };
    size_t nb_st_f32[4] = { nb_f32, S_v * nb_f32, S_v * S_v * nb_f32, S_v * S_v * H * nb_f32 };
    size_t nb_st_bf16[4] = { nb_bf16, S_v * nb_bf16, S_v * S_v * nb_bf16, S_v * S_v * H * nb_bf16 };
    ggml_cann_pool_alloc state_bf16_alloc(ctx.pool(), state_elems * nb_bf16);
    acl_tensor_ptr acl_state_f32 = ggml_cann_create_tensor(src_state->data, ACL_FLOAT, nb_f32, ne_st, nb_st_f32, 4);
    acl_tensor_ptr acl_state_bf16 = ggml_cann_create_tensor(state_bf16_alloc.get(), ACL_BF16, nb_bf16, ne_st, nb_st_bf16, 4);
    aclnn_cast(ctx, acl_state_f32.get(), acl_state_bf16.get(), ACL_BF16);

    // Output buffer in BF16: (T, Nv, Dv) — same layout as V
    ggml_cann_pool_alloc out_bf16_alloc(ctx.pool(), T * H * S_v * nb_bf16);
    acl_tensor_ptr acl_out_bf16 = ggml_cann_create_tensor(out_bf16_alloc.get(), ACL_BF16, nb_bf16, ne_v, nb_v_bf16, 3);

    // ---- Prepare INT32 auxiliary tensors ----
    // actualSeqLengths: (B,) — each sequence has n_tokens tokens
    std::vector<int32_t> host_seq_lens(n_seqs, (int32_t) n_tokens);
    ggml_cann_pool_alloc asl_alloc(ctx.pool(), n_seqs * nb_i32);
    ACL_CHECK(aclrtMemcpy(asl_alloc.get(), n_seqs * nb_i32,
                          host_seq_lens.data(), n_seqs * nb_i32,
                          ACL_MEMCPY_HOST_TO_DEVICE));
    int64_t ne_b[1] = { n_seqs };
    size_t nb_b[1] = { nb_i32 };
    acl_tensor_ptr acl_asl = ggml_cann_create_tensor(asl_alloc.get(), ACL_INT32, nb_i32, ne_b, nb_b, 1);

    // ssmStateIndices: (T,) — the token at (seq s, pos t) maps to state block s
    std::vector<int32_t> host_ssm_idx(T);
    for (int64_t s = 0; s < n_seqs; s++) {
        for (int64_t t = 0; t < n_tokens; t++) {
            host_ssm_idx[s * n_tokens + t] = (int32_t) s;
        }
    }
    ggml_cann_pool_alloc ssm_alloc(ctx.pool(), T * nb_i32);
    ACL_CHECK(aclrtMemcpy(ssm_alloc.get(), T * nb_i32,
                          host_ssm_idx.data(), T * nb_i32,
                          ACL_MEMCPY_HOST_TO_DEVICE));
    int64_t ne_T[1] = { T };
    size_t nb_T[1] = { nb_i32 };
    acl_tensor_ptr acl_ssm = ggml_cann_create_tensor(ssm_alloc.get(), ACL_INT32, nb_i32, ne_T, nb_T, 1);

    // numAcceptedTokens: (B,) — all tokens are accepted
    ggml_cann_pool_alloc nat_alloc(ctx.pool(), n_seqs * nb_i32);
    ACL_CHECK(aclrtMemcpy(nat_alloc.get(), n_seqs * nb_i32,
                          host_seq_lens.data(), n_seqs * nb_i32,
                          ACL_MEMCPY_HOST_TO_DEVICE));
    acl_tensor_ptr acl_nat = ggml_cann_create_tensor(nat_alloc.get(), ACL_INT32, nb_i32, ne_b, nb_b, 1);
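
    // Note: the nullptr argument below fills the optional gk slot (the
    // per-channel KDA gate); it stays unset here since the fused path
    // excludes KDA.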
    // ---- Call fused operator ----
    GGML_CANN_CALL_ACLNN_OP(ctx, RecurrentGatedDeltaRule,
                            acl_q_bf16.get(), acl_k_bf16.get(), acl_v_bf16.get(),
                            acl_beta_bf16.get(), acl_state_bf16.get(),
                            acl_asl.get(), acl_ssm.get(),
                            acl_g.get(), nullptr, acl_nat.get(),
                            scale, acl_out_bf16.get());

    // ---- Cast BF16 outputs back to F32 ----
    // Attention output → dst[0 .. state_out_offset)
    acl_tensor_ptr acl_dst_attn = ggml_cann_create_tensor(
        dst->data, ACL_FLOAT, nb_f32, ne_v, nb_v_f32, 3);
    aclnn_cast(ctx, acl_out_bf16.get(), acl_dst_attn.get(), ACL_FLOAT);

    // Updated state → dst[state_out_offset .. end)
    acl_tensor_ptr acl_dst_state = ggml_cann_create_tensor(
        dst->data, ACL_FLOAT, nb_f32, ne_st, nb_st_f32, 4, ACL_FORMAT_ND, state_out_offset);
    aclnn_cast(ctx, acl_state_bf16.get(), acl_dst_state.get(), ACL_FLOAT);
}
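
Since both code paths emit the same packed dst layout, a consumer can recover the two regions with plain pointer arithmetic. A minimal sketch, assuming the F32 packing above; the struct and helper names are illustrative, not part of the backend:

#include <cstdint>

// Views into a packed GATED_DELTA_NET result: [attn_scores | new_states].
struct gdn_views {
    float * attn;   // [S_v, H, n_tokens, n_seqs], innermost dim first
    float * state;  // [S_v, S_v, H, n_seqs], innermost dim first
};

static gdn_views gdn_split(float * dst_data, int64_t S_v, int64_t H,
                           int64_t n_tokens, int64_t n_seqs) {
    const int64_t attn_elems = S_v * H * n_tokens * n_seqs;  // == attn_score_elems
    return { dst_data, dst_data + attn_elems };
}
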


@ -847,6 +847,27 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst
 */
void ggml_cann_gated_linear_attn(ggml_backend_cann_context & ctx, ggml_tensor * dst);

/**
 * @brief Forward pass of Gated Delta Net on the CANN backend.
 *
 * Expects dst->src[0..5] = {q, k, v, g, beta, state} with shape conventions:
 *   q, k:  [S_v, H_q/H_k, n_tokens, n_seqs] (contiguous rows)
 *   v:     [S_v, H, n_tokens, n_seqs]
 *   g:     [1, H, n_tokens, n_seqs] (scalar gate) or [S_v, H, n_tokens, n_seqs] (KDA)
 *   beta:  [1, H, n_tokens, n_seqs]
 *   state: [S_v, S_v, H, n_seqs]
 *
 * Per-token recurrence:
 *   S_t   = exp(g_t) * S_{t-1} + k_t * (v_t - S_{t-1}^T k_t)^T * beta_t
 *   out_t = S_t^T q_t / sqrt(S_v)
 *
 * dst holds both the attention outputs and the updated state.
 *
 * @param ctx Backend context providing stream/allocator utilities.
 * @param dst Output tensor; src deps are q, k, v, g, beta, state as above.
 */
void ggml_cann_gated_delta_net(ggml_backend_cann_context & ctx, ggml_tensor * dst);
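
In LaTeX form, the per-token update documented above reads (an editorial restatement, not part of the header):

    S_t = e^{g_t}\, S_{t-1} + \beta_t\, k_t \left( v_t - S_{t-1}^{\top} k_t \right)^{\top},
    \qquad
    o_t = \frac{S_t^{\top} q_t}{\sqrt{S_v}}

In KDA mode the scalar e^{g_t} is replaced by \mathrm{diag}(e^{g_t}) with g_t \in \mathbb{R}^{S_v}, scaling the rows of S_{t-1}.
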
/**
 * @brief Launches an asynchronous task using the memory allocator.
 *


@ -1905,6 +1905,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
        case GGML_OP_SSM_CONV:
            ggml_cann_ssm_conv(ctx, dst);
            break;
        case GGML_OP_GATED_DELTA_NET:
            ggml_cann_gated_delta_net(ctx, dst);
            break;
        default:
            return false;
    }
@ -2565,6 +2568,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
            return true;
        }
        case GGML_OP_SSM_CONV:
        case GGML_OP_GATED_DELTA_NET:
            return true;
        default:
            return false;