From 140c5a3d1bae3237eb582f30ddcd6d8a74e74174 Mon Sep 17 00:00:00 2001 From: hipudding Date: Fri, 27 Mar 2026 08:52:24 +0000 Subject: [PATCH] CANN: add GATED_DELTA_NET op support --- ggml/src/ggml-cann/aclnn_ops.cpp | 321 +++++++++++++++++++++++++++++++ ggml/src/ggml-cann/aclnn_ops.h | 21 ++ ggml/src/ggml-cann/ggml-cann.cpp | 4 + 3 files changed, 346 insertions(+) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 315bea0c84..02cb844c63 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -62,6 +62,7 @@ #include #include #include +#include #include #include #include @@ -4187,3 +4188,323 @@ void ggml_cann_gated_linear_attn(ggml_backend_cann_context & ctx, ggml_tensor * } } } + +static void ggml_cann_gated_delta_net_naive(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src_q = dst->src[0]; + ggml_tensor * src_k = dst->src[1]; + ggml_tensor * src_v = dst->src[2]; + ggml_tensor * src_g = dst->src[3]; + ggml_tensor * src_beta = dst->src[4]; + ggml_tensor * src_state = dst->src[5]; + + const int64_t S_v = src_v->ne[0]; + const int64_t H = src_v->ne[1]; + const int64_t n_tokens = src_v->ne[2]; + const int64_t n_seqs = src_v->ne[3]; + + const int64_t neq1 = src_q->ne[1]; + const int64_t nek1 = src_k->ne[1]; + const int64_t neq3 = src_q->ne[3]; + const int64_t nek3 = src_k->ne[3]; + + const bool kda = (src_g->ne[0] == S_v); + + const float scale = 1.0f / sqrtf((float) S_v); + + const size_t nb_f32 = sizeof(float); + + // Q/K strides (may differ from V for GQA) + const size_t nbq1 = src_q->nb[1], nbq2 = src_q->nb[2], nbq3 = src_q->nb[3]; + const size_t nbk1 = src_k->nb[1], nbk2 = src_k->nb[2], nbk3 = src_k->nb[3]; + const size_t nbv1 = src_v->nb[1], nbv2 = src_v->nb[2], nbv3 = src_v->nb[3]; + const size_t nbg1 = src_g->nb[1], nbg2 = src_g->nb[2], nbg3 = src_g->nb[3]; + const size_t nbb1 = src_beta->nb[1], nbb2 = src_beta->nb[2], nbb3 = src_beta->nb[3]; + + const int64_t rq3 = (neq3 > 0) ? n_seqs / neq3 : 1; + const int64_t rk3 = (nek3 > 0) ? n_seqs / nek3 : 1; + + // Output layout: [attn_scores | new_states] + const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; + const size_t state_out_offset = attn_score_elems * nb_f32; + + // Shapes for per-head operations + int64_t ne_s[2] = { S_v, S_v }; + size_t nb_s[2] = { nb_f32, S_v * nb_f32 }; + int64_t ne_vec[1] = { S_v }; + size_t nb_vec[1] = { nb_f32 }; + int64_t ne_sc[1] = { 1 }; + size_t nb_sc[1] = { nb_f32 }; + int64_t ne_g_bc[2] = { S_v, 1 }; // for KDA gate broadcast + size_t nb_g_bc[2] = { nb_f32, S_v * nb_f32 }; + + // Copy input state to output state area + { + int64_t ne_flat[1] = { S_v * S_v * H * n_seqs }; + size_t nb_flat[1] = { nb_f32 }; + acl_tensor_ptr acl_sin = ggml_cann_create_tensor( + src_state->data, ACL_FLOAT, nb_f32, ne_flat, nb_flat, 1); + acl_tensor_ptr acl_sout = ggml_cann_create_tensor( + dst->data, ACL_FLOAT, nb_f32, ne_flat, nb_flat, 1, ACL_FORMAT_ND, state_out_offset); + cann_copy(ctx, acl_sin.get(), acl_sout.get()); + } + + for (int64_t s = 0; s < n_seqs; s++) { + for (int64_t h = 0; h < H; h++) { + const size_t s_off = state_out_offset + ((s * H + h) * S_v * S_v) * nb_f32; + + const int64_t iq1 = h % neq1; + const int64_t ik1 = h % nek1; + const int64_t iq3 = s / rq3; + const int64_t ik3 = s / rk3; + + for (int64_t t = 0; t < n_tokens; t++) { + // State matrix view [S_v, S_v] (transposed storage: M[j][i] = S[i][j]) + // Mv(M, k) = S^T @ k, Mv(M, q) = S^T @ q + acl_tensor_ptr acl_s_mat = ggml_cann_create_tensor( + dst->data, ACL_FLOAT, nb_f32, ne_s, nb_s, 2, ACL_FORMAT_ND, s_off); + + // Input tensor views + const size_t q_off = iq3 * nbq3 + t * nbq2 + iq1 * nbq1; + const size_t k_off = ik3 * nbk3 + t * nbk2 + ik1 * nbk1; + const size_t v_off = s * nbv3 + t * nbv2 + h * nbv1; + const size_t beta_off = s * nbb3 + t * nbb2 + h * nbb1; + const size_t g_off = s * nbg3 + t * nbg2 + h * nbg1; + + acl_tensor_ptr acl_q = ggml_cann_create_tensor( + src_q->data, ACL_FLOAT, nb_f32, ne_vec, nb_vec, 1, ACL_FORMAT_ND, q_off); + acl_tensor_ptr acl_k = ggml_cann_create_tensor( + src_k->data, ACL_FLOAT, nb_f32, ne_vec, nb_vec, 1, ACL_FORMAT_ND, k_off); + acl_tensor_ptr acl_v = ggml_cann_create_tensor( + src_v->data, ACL_FLOAT, nb_f32, ne_vec, nb_vec, 1, ACL_FORMAT_ND, v_off); + acl_tensor_ptr acl_beta = ggml_cann_create_tensor( + src_beta->data, ACL_FLOAT, nb_f32, ne_sc, nb_sc, 1, ACL_FORMAT_ND, beta_off); + + // Step 1: State decay S *= exp(g) + if (kda) { + // KDA mode: M[j][i] *= exp(g[i]) for all j + ggml_cann_pool_alloc g_exp_alloc(ctx.pool(), S_v * nb_f32); + acl_tensor_ptr acl_g_src = ggml_cann_create_tensor( + src_g->data, ACL_FLOAT, nb_f32, ne_vec, nb_vec, 1, ACL_FORMAT_ND, g_off); + acl_tensor_ptr acl_g_exp = ggml_cann_create_tensor( + g_exp_alloc.get(), ACL_FLOAT, nb_f32, ne_vec, nb_vec, 1); + cann_copy(ctx, acl_g_src.get(), acl_g_exp.get()); + aclnn_exp(ctx, acl_g_exp.get()); + // Broadcast: GGML [S_v,1] -> CANN [1,S_v] broadcasts along rows + acl_tensor_ptr acl_g_exp_bc = ggml_cann_create_tensor( + g_exp_alloc.get(), ACL_FLOAT, nb_f32, ne_g_bc, nb_g_bc, 2); + aclnn_mul(ctx, acl_s_mat.get(), acl_g_exp_bc.get(), nullptr); + } else { + // Scalar mode: M *= exp(g[0]) + ggml_cann_pool_alloc g_exp_alloc(ctx.pool(), nb_f32); + acl_tensor_ptr acl_g_src = ggml_cann_create_tensor( + src_g->data, ACL_FLOAT, nb_f32, ne_sc, nb_sc, 1, ACL_FORMAT_ND, g_off); + acl_tensor_ptr acl_g_exp = ggml_cann_create_tensor( + g_exp_alloc.get(), ACL_FLOAT, nb_f32, ne_sc, nb_sc, 1); + cann_copy(ctx, acl_g_src.get(), acl_g_exp.get()); + aclnn_exp(ctx, acl_g_exp.get()); + aclnn_mul(ctx, acl_s_mat.get(), acl_g_exp.get(), nullptr); + } + + // Step 2: delta = (v - M @ k) * beta + ggml_cann_pool_alloc kv_alloc(ctx.pool(), S_v * nb_f32); + acl_tensor_ptr acl_kv = ggml_cann_create_tensor( + kv_alloc.get(), ACL_FLOAT, nb_f32, ne_vec, nb_vec, 1); + GGML_CANN_CALL_ACLNN_OP(ctx, Mv, acl_s_mat.get(), acl_k.get(), acl_kv.get(), 1); + + ggml_cann_pool_alloc delta_alloc(ctx.pool(), S_v * nb_f32); + acl_tensor_ptr acl_delta = ggml_cann_create_tensor( + delta_alloc.get(), ACL_FLOAT, nb_f32, ne_vec, nb_vec, 1); + aclnn_sub(ctx, acl_v.get(), acl_kv.get(), acl_delta.get()); + aclnn_mul(ctx, acl_delta.get(), acl_beta.get(), nullptr); + + // Step 3: State update M += delta ⊗ k (outer product) + ggml_cann_pool_alloc outer_alloc(ctx.pool(), S_v * S_v * nb_f32); + acl_tensor_ptr acl_outer = ggml_cann_create_tensor( + outer_alloc.get(), ACL_FLOAT, nb_f32, ne_s, nb_s, 2); + GGML_CANN_CALL_ACLNN_OP(ctx, Ger, acl_delta.get(), acl_k.get(), acl_outer.get()); + aclnn_add(ctx, acl_s_mat.get(), acl_outer.get(), nullptr); + + // Step 4: Output attn = M @ q * scale + const size_t attn_off = ((s * n_tokens * H + t * H + h) * S_v) * nb_f32; + acl_tensor_ptr acl_attn = ggml_cann_create_tensor( + dst->data, ACL_FLOAT, nb_f32, ne_vec, nb_vec, 1, ACL_FORMAT_ND, attn_off); + GGML_CANN_CALL_ACLNN_OP(ctx, Mv, acl_s_mat.get(), acl_q.get(), acl_attn.get(), 1); + aclnn_muls(ctx, acl_attn.get(), scale, nullptr, true); + } + } + } +} + +void ggml_cann_gated_delta_net(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src_q = dst->src[0]; + ggml_tensor * src_k = dst->src[1]; + ggml_tensor * src_v = dst->src[2]; + ggml_tensor * src_g = dst->src[3]; + ggml_tensor * src_beta = dst->src[4]; + ggml_tensor * src_state = dst->src[5]; + + const int64_t S_v = src_v->ne[0]; + const int64_t H = src_v->ne[1]; + const int64_t n_tokens = src_v->ne[2]; + const int64_t n_seqs = src_v->ne[3]; + + const int64_t neq1 = src_q->ne[1]; + const int64_t nek1 = src_k->ne[1]; + const int64_t neq3 = src_q->ne[3]; + const int64_t nek3 = src_k->ne[3]; + + const bool kda = (src_g->ne[0] == S_v); + + // Check if we can use the fused aclnnRecurrentGatedDeltaRule operator. + // Constraints from the operator spec: + // - gk (KDA mode) is not supported in current CANN version + // - Li <= 8 (per-sequence token count) + // - Nk, Nv, Dk, Dv <= 256 + // - Q/K/V head counts must match (Nk == Nv, no GQA) + // - No batch-dimension broadcasting for Q/K + // - Input tensors must be contiguous (fused path assumes contiguous layout) + const bool use_fused = !kda + && n_tokens <= 8 + && S_v <= 256 + && H <= 256 + && neq1 <= 256 + && neq1 == nek1 + && neq1 == H + && neq3 == n_seqs + && nek3 == n_seqs + && ggml_is_contiguous(src_q) + && ggml_is_contiguous(src_k) + && ggml_is_contiguous(src_v); + + if (!use_fused) { + ggml_cann_gated_delta_net_naive(ctx, dst); + return; + } + + const int64_t T = n_seqs * n_tokens; + const float scale = 1.0f / sqrtf((float) S_v); + + const size_t nb_f32 = sizeof(float); + const size_t nb_bf16 = sizeof(uint16_t); + const size_t nb_i32 = sizeof(int32_t); + + // Output layout: [attn_scores | new_states] + const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; + const size_t state_out_offset = attn_score_elems * nb_f32; + + // ---- Cast F32 inputs to BF16 ---- + + // Q: GGML [S_v, neq1, n_tokens, n_seqs] → 3D [S_v, neq1, T] → CANN (T, Nk, Dk) + int64_t ne_q[3] = { S_v, neq1, T }; + size_t nb_q_f32[3] = { nb_f32, S_v * nb_f32, S_v * neq1 * nb_f32 }; + size_t nb_q_bf16[3] = { nb_bf16, S_v * nb_bf16, S_v * neq1 * nb_bf16 }; + + ggml_cann_pool_alloc q_bf16_alloc(ctx.pool(), T * neq1 * S_v * nb_bf16); + acl_tensor_ptr acl_q_f32 = ggml_cann_create_tensor(src_q->data, ACL_FLOAT, nb_f32, ne_q, nb_q_f32, 3); + acl_tensor_ptr acl_q_bf16 = ggml_cann_create_tensor(q_bf16_alloc.get(), ACL_BF16, nb_bf16, ne_q, nb_q_bf16, 3); + aclnn_cast(ctx, acl_q_f32.get(), acl_q_bf16.get(), ACL_BF16); + + // K: GGML [S_v, nek1, n_tokens, n_seqs] → 3D [S_v, nek1, T] → CANN (T, Nk, Dk) + int64_t ne_k[3] = { S_v, nek1, T }; + size_t nb_k_f32[3] = { nb_f32, S_v * nb_f32, S_v * nek1 * nb_f32 }; + size_t nb_k_bf16[3] = { nb_bf16, S_v * nb_bf16, S_v * nek1 * nb_bf16 }; + + ggml_cann_pool_alloc k_bf16_alloc(ctx.pool(), T * nek1 * S_v * nb_bf16); + acl_tensor_ptr acl_k_f32 = ggml_cann_create_tensor(src_k->data, ACL_FLOAT, nb_f32, ne_k, nb_k_f32, 3); + acl_tensor_ptr acl_k_bf16 = ggml_cann_create_tensor(k_bf16_alloc.get(), ACL_BF16, nb_bf16, ne_k, nb_k_bf16, 3); + aclnn_cast(ctx, acl_k_f32.get(), acl_k_bf16.get(), ACL_BF16); + + // V: GGML [S_v, H, n_tokens, n_seqs] → 3D [S_v, H, T] → CANN (T, Nv, Dv) + int64_t ne_v[3] = { S_v, H, T }; + size_t nb_v_f32[3] = { nb_f32, S_v * nb_f32, S_v * H * nb_f32 }; + size_t nb_v_bf16[3] = { nb_bf16, S_v * nb_bf16, S_v * H * nb_bf16 }; + + ggml_cann_pool_alloc v_bf16_alloc(ctx.pool(), T * H * S_v * nb_bf16); + acl_tensor_ptr acl_v_f32 = ggml_cann_create_tensor(src_v->data, ACL_FLOAT, nb_f32, ne_v, nb_v_f32, 3); + acl_tensor_ptr acl_v_bf16 = ggml_cann_create_tensor(v_bf16_alloc.get(), ACL_BF16, nb_bf16, ne_v, nb_v_bf16, 3); + aclnn_cast(ctx, acl_v_f32.get(), acl_v_bf16.get(), ACL_BF16); + + // Beta: GGML [1, H, n_tokens, n_seqs] → 2D [H, T] → CANN (T, Nv) + int64_t ne_hf[2] = { H, T }; + size_t nb_hf_f32[2] = { nb_f32, H * nb_f32 }; + size_t nb_hf_bf16[2] = { nb_bf16, H * nb_bf16 }; + + ggml_cann_pool_alloc beta_bf16_alloc(ctx.pool(), T * H * nb_bf16); + acl_tensor_ptr acl_beta_f32 = ggml_cann_create_tensor(src_beta->data, ACL_FLOAT, nb_f32, ne_hf, nb_hf_f32, 2); + acl_tensor_ptr acl_beta_bf16 = ggml_cann_create_tensor(beta_bf16_alloc.get(), ACL_BF16, nb_bf16, ne_hf, nb_hf_bf16, 2); + aclnn_cast(ctx, acl_beta_f32.get(), acl_beta_bf16.get(), ACL_BF16); + + // Gate g: GGML [1, H, n_tokens, n_seqs] → 2D [H, T] → CANN (T, Nv), stays F32 + acl_tensor_ptr acl_g = ggml_cann_create_tensor(src_g->data, ACL_FLOAT, nb_f32, ne_hf, nb_hf_f32, 2); + + // State: GGML [S_v, S_v, H, n_seqs] → CANN (BlockNum, Nv, Dv, Dk) + // GGML stores M = S^T, but the recurrence applied to M has the same form as the + // standard delta rule, so M can be passed directly as the API's state parameter. + const int64_t state_elems = n_seqs * H * S_v * S_v; + int64_t ne_st[4] = { S_v, S_v, H, n_seqs }; + size_t nb_st_f32[4] = { nb_f32, S_v * nb_f32, S_v * S_v * nb_f32, S_v * S_v * H * nb_f32 }; + size_t nb_st_bf16[4] = { nb_bf16, S_v * nb_bf16, S_v * S_v * nb_bf16, S_v * S_v * H * nb_bf16 }; + + ggml_cann_pool_alloc state_bf16_alloc(ctx.pool(), state_elems * nb_bf16); + acl_tensor_ptr acl_state_f32 = ggml_cann_create_tensor(src_state->data, ACL_FLOAT, nb_f32, ne_st, nb_st_f32, 4); + acl_tensor_ptr acl_state_bf16 = ggml_cann_create_tensor(state_bf16_alloc.get(), ACL_BF16, nb_bf16, ne_st, nb_st_bf16, 4); + aclnn_cast(ctx, acl_state_f32.get(), acl_state_bf16.get(), ACL_BF16); + + // Output buffer in BF16: (T, Nv, Dv) — same layout as V + ggml_cann_pool_alloc out_bf16_alloc(ctx.pool(), T * H * S_v * nb_bf16); + acl_tensor_ptr acl_out_bf16 = ggml_cann_create_tensor(out_bf16_alloc.get(), ACL_BF16, nb_bf16, ne_v, nb_v_bf16, 3); + + // ---- Prepare INT32 auxiliary tensors ---- + + // actualSeqLengths: (B,) — each sequence has n_tokens tokens + std::vector host_seq_lens(n_seqs, (int32_t) n_tokens); + ggml_cann_pool_alloc asl_alloc(ctx.pool(), n_seqs * nb_i32); + ACL_CHECK(aclrtMemcpy(asl_alloc.get(), n_seqs * nb_i32, + host_seq_lens.data(), n_seqs * nb_i32, + ACL_MEMCPY_HOST_TO_DEVICE)); + int64_t ne_b[1] = { n_seqs }; + size_t nb_b[1] = { nb_i32 }; + acl_tensor_ptr acl_asl = ggml_cann_create_tensor(asl_alloc.get(), ACL_INT32, nb_i32, ne_b, nb_b, 1); + + // ssmStateIndices: (T,) — token at seq s, pos t maps to state block s + std::vector host_ssm_idx(T); + for (int64_t s = 0; s < n_seqs; s++) { + for (int64_t t = 0; t < n_tokens; t++) { + host_ssm_idx[s * n_tokens + t] = (int32_t) s; + } + } + ggml_cann_pool_alloc ssm_alloc(ctx.pool(), T * nb_i32); + ACL_CHECK(aclrtMemcpy(ssm_alloc.get(), T * nb_i32, + host_ssm_idx.data(), T * nb_i32, + ACL_MEMCPY_HOST_TO_DEVICE)); + int64_t ne_T[1] = { T }; + size_t nb_T[1] = { nb_i32 }; + acl_tensor_ptr acl_ssm = ggml_cann_create_tensor(ssm_alloc.get(), ACL_INT32, nb_i32, ne_T, nb_T, 1); + + // numAcceptedTokens: (B,) — all tokens are accepted + ggml_cann_pool_alloc nat_alloc(ctx.pool(), n_seqs * nb_i32); + ACL_CHECK(aclrtMemcpy(nat_alloc.get(), n_seqs * nb_i32, + host_seq_lens.data(), n_seqs * nb_i32, + ACL_MEMCPY_HOST_TO_DEVICE)); + acl_tensor_ptr acl_nat = ggml_cann_create_tensor(nat_alloc.get(), ACL_INT32, nb_i32, ne_b, nb_b, 1); + + // ---- Call fused operator ---- + GGML_CANN_CALL_ACLNN_OP(ctx, RecurrentGatedDeltaRule, + acl_q_bf16.get(), acl_k_bf16.get(), acl_v_bf16.get(), + acl_beta_bf16.get(), acl_state_bf16.get(), + acl_asl.get(), acl_ssm.get(), + acl_g.get(), nullptr, acl_nat.get(), + scale, acl_out_bf16.get()); + + // ---- Cast BF16 outputs back to F32 ---- + + // Attention output → dst[0 .. state_out_offset) + acl_tensor_ptr acl_dst_attn = ggml_cann_create_tensor( + dst->data, ACL_FLOAT, nb_f32, ne_v, nb_v_f32, 3); + aclnn_cast(ctx, acl_out_bf16.get(), acl_dst_attn.get(), ACL_FLOAT); + + // Updated state → dst[state_out_offset .. end) + acl_tensor_ptr acl_dst_state = ggml_cann_create_tensor( + dst->data, ACL_FLOAT, nb_f32, ne_st, nb_st_f32, 4, ACL_FORMAT_ND, state_out_offset); + aclnn_cast(ctx, acl_state_bf16.get(), acl_dst_state.get(), ACL_FLOAT); +} diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h index a68e9119ae..19d1d65bf0 100644 --- a/ggml/src/ggml-cann/aclnn_ops.h +++ b/ggml/src/ggml-cann/aclnn_ops.h @@ -847,6 +847,27 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst */ void ggml_cann_gated_linear_attn(ggml_backend_cann_context & ctx, ggml_tensor * dst); +/** + * @brief Forward Gated Delta Net on the CANN backend. + * + * Expects dst->src[0..5] = {q, k, v, g, beta, state} with shape conventions: + * q, k: [S_v, H_q/H_k, n_tokens, n_seqs] (contiguous rows) + * v: [S_v, H, n_tokens, n_seqs] + * g: [1, H, n_tokens, n_seqs] (scalar gate) or [S_v, H, n_tokens, n_seqs] (KDA) + * beta: [1, H, n_tokens, n_seqs] + * state:[S_v, S_v, H, n_seqs] + * + * Per token recurrence: + * S_t = exp(g_t) * S_{t-1} + k_t * (v_t - S_{t-1}^T k_t)^T * beta_t + * out_t = S_t^T q_t / sqrt(S_v) + * + * dst holds both attention outputs and updated state. + * + * @param ctx Backend context providing stream/allocator utilities. + * @param dst Output tensor; src deps are q, k, v, g, beta, state as above. + */ +void ggml_cann_gated_delta_net(ggml_backend_cann_context & ctx, ggml_tensor * dst); + /** * @brief Launches an asynchronous task using the memory allocator. * diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index f768de4d86..eabd2d4979 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -1905,6 +1905,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg case GGML_OP_SSM_CONV: ggml_cann_ssm_conv(ctx, dst); break; + case GGML_OP_GATED_DELTA_NET: + ggml_cann_gated_delta_net(ctx, dst); + break; default: return false; } @@ -2565,6 +2568,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten return true; } case GGML_OP_SSM_CONV: + case GGML_OP_GATED_DELTA_NET: return true; default: return false;