From 11e78d8499a0e7f22c062c9a0b0c61c41713e9e0 Mon Sep 17 00:00:00 2001 From: hipudding Date: Sat, 28 Mar 2026 01:02:32 +0000 Subject: [PATCH] CANN: simplify GATED_DELTA_NET implementation - Remove dead code: _math and _naive variants are no longer needed - Rename _batched to the public entry point ggml_cann_gated_delta_net - In supports_op, return false for non-contiguous / GQA / non-F32 cases so the framework falls back to CPU instead of running the slow naive path - The single remaining implementation uses aclnnBatchMatMul over all H heads per timestep, reducing kernel launches to O(n_seqs * n_tokens) --- ggml/src/ggml-cann/aclnn_ops.cpp | 406 +------------------------------ ggml/src/ggml-cann/ggml-cann.cpp | 23 +- 2 files changed, 24 insertions(+), 405 deletions(-) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index f0471d2a6b..8287cadeae 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -4189,235 +4189,7 @@ void ggml_cann_gated_linear_attn(ggml_backend_cann_context & ctx, ggml_tensor * } } -static void ggml_cann_gated_delta_net_naive(ggml_backend_cann_context & ctx, ggml_tensor * dst) { - ggml_tensor * src_q = dst->src[0]; - ggml_tensor * src_k = dst->src[1]; - ggml_tensor * src_v = dst->src[2]; - ggml_tensor * src_g = dst->src[3]; - ggml_tensor * src_beta = dst->src[4]; - ggml_tensor * src_state = dst->src[5]; - - const int64_t S_v = src_v->ne[0]; - const int64_t H = src_v->ne[1]; - const int64_t n_tokens = src_v->ne[2]; - const int64_t n_seqs = src_v->ne[3]; - - const int64_t neq1 = src_q->ne[1]; - const int64_t nek1 = src_k->ne[1]; - const int64_t neq3 = src_q->ne[3]; - const int64_t nek3 = src_k->ne[3]; - - const bool kda = (src_g->ne[0] == S_v); - - const float scale = 1.0f / sqrtf((float) S_v); - - const size_t nb_f32 = sizeof(float); - - // Q/K strides (may differ from V for GQA) - const size_t nbq1 = src_q->nb[1], nbq2 = src_q->nb[2], nbq3 = src_q->nb[3]; - const size_t nbk1 = src_k->nb[1], nbk2 = src_k->nb[2], nbk3 = src_k->nb[3]; - const size_t nbv1 = src_v->nb[1], nbv2 = src_v->nb[2], nbv3 = src_v->nb[3]; - const size_t nbg1 = src_g->nb[1], nbg2 = src_g->nb[2], nbg3 = src_g->nb[3]; - const size_t nbb1 = src_beta->nb[1], nbb2 = src_beta->nb[2], nbb3 = src_beta->nb[3]; - - const int64_t rq3 = (neq3 > 0) ? n_seqs / neq3 : 1; - const int64_t rk3 = (nek3 > 0) ? 
n_seqs / nek3 : 1; - - // Output layout: [attn_scores | new_states] - const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; - const size_t state_out_offset = attn_score_elems * nb_f32; - - // Shapes for per-head operations - int64_t ne_s[2] = { S_v, S_v }; - size_t nb_s[2] = { nb_f32, S_v * nb_f32 }; - int64_t ne_vec[1] = { S_v }; - size_t nb_vec[1] = { nb_f32 }; - int64_t ne_sc[1] = { 1 }; - size_t nb_sc[1] = { nb_f32 }; - int64_t ne_g_bc[2] = { S_v, 1 }; // for KDA gate broadcast - size_t nb_g_bc[2] = { nb_f32, S_v * nb_f32 }; - - // Copy input state to output state area - { - int64_t ne_flat[1] = { S_v * S_v * H * n_seqs }; - size_t nb_flat[1] = { nb_f32 }; - acl_tensor_ptr acl_sin = ggml_cann_create_tensor( - src_state->data, ACL_FLOAT, nb_f32, ne_flat, nb_flat, 1); - acl_tensor_ptr acl_sout = ggml_cann_create_tensor( - dst->data, ACL_FLOAT, nb_f32, ne_flat, nb_flat, 1, ACL_FORMAT_ND, state_out_offset); - cann_copy(ctx, acl_sin.get(), acl_sout.get()); - } - - for (int64_t s = 0; s < n_seqs; s++) { - for (int64_t h = 0; h < H; h++) { - const size_t s_off = state_out_offset + ((s * H + h) * S_v * S_v) * nb_f32; - - const int64_t iq1 = h % neq1; - const int64_t ik1 = h % nek1; - const int64_t iq3 = s / rq3; - const int64_t ik3 = s / rk3; - - for (int64_t t = 0; t < n_tokens; t++) { - // State matrix view [S_v, S_v] (transposed storage: M[j][i] = S[i][j]) - // Mv(M, k) = S^T @ k, Mv(M, q) = S^T @ q - acl_tensor_ptr acl_s_mat = ggml_cann_create_tensor( - dst->data, ACL_FLOAT, nb_f32, ne_s, nb_s, 2, ACL_FORMAT_ND, s_off); - - // Input tensor views - const size_t q_off = iq3 * nbq3 + t * nbq2 + iq1 * nbq1; - const size_t k_off = ik3 * nbk3 + t * nbk2 + ik1 * nbk1; - const size_t v_off = s * nbv3 + t * nbv2 + h * nbv1; - const size_t beta_off = s * nbb3 + t * nbb2 + h * nbb1; - const size_t g_off = s * nbg3 + t * nbg2 + h * nbg1; - - acl_tensor_ptr acl_q = ggml_cann_create_tensor( - src_q->data, ACL_FLOAT, nb_f32, ne_vec, nb_vec, 1, ACL_FORMAT_ND, q_off); - acl_tensor_ptr acl_k = ggml_cann_create_tensor( - src_k->data, ACL_FLOAT, nb_f32, ne_vec, nb_vec, 1, ACL_FORMAT_ND, k_off); - acl_tensor_ptr acl_v = ggml_cann_create_tensor( - src_v->data, ACL_FLOAT, nb_f32, ne_vec, nb_vec, 1, ACL_FORMAT_ND, v_off); - acl_tensor_ptr acl_beta = ggml_cann_create_tensor( - src_beta->data, ACL_FLOAT, nb_f32, ne_sc, nb_sc, 1, ACL_FORMAT_ND, beta_off); - - // Step 1: State decay S *= exp(g) - if (kda) { - // KDA mode: M[j][i] *= exp(g[i]) for all j - ggml_cann_pool_alloc g_exp_alloc(ctx.pool(), S_v * nb_f32); - acl_tensor_ptr acl_g_src = ggml_cann_create_tensor( - src_g->data, ACL_FLOAT, nb_f32, ne_vec, nb_vec, 1, ACL_FORMAT_ND, g_off); - acl_tensor_ptr acl_g_exp = ggml_cann_create_tensor( - g_exp_alloc.get(), ACL_FLOAT, nb_f32, ne_vec, nb_vec, 1); - cann_copy(ctx, acl_g_src.get(), acl_g_exp.get()); - aclnn_exp(ctx, acl_g_exp.get()); - // Broadcast: GGML [S_v,1] -> CANN [1,S_v] broadcasts along rows - acl_tensor_ptr acl_g_exp_bc = ggml_cann_create_tensor( - g_exp_alloc.get(), ACL_FLOAT, nb_f32, ne_g_bc, nb_g_bc, 2); - aclnn_mul(ctx, acl_s_mat.get(), acl_g_exp_bc.get(), nullptr); - } else { - // Scalar mode: M *= exp(g[0]) - ggml_cann_pool_alloc g_exp_alloc(ctx.pool(), nb_f32); - acl_tensor_ptr acl_g_src = ggml_cann_create_tensor( - src_g->data, ACL_FLOAT, nb_f32, ne_sc, nb_sc, 1, ACL_FORMAT_ND, g_off); - acl_tensor_ptr acl_g_exp = ggml_cann_create_tensor( - g_exp_alloc.get(), ACL_FLOAT, nb_f32, ne_sc, nb_sc, 1); - cann_copy(ctx, acl_g_src.get(), acl_g_exp.get()); - aclnn_exp(ctx, acl_g_exp.get()); - 
aclnn_mul(ctx, acl_s_mat.get(), acl_g_exp.get(), nullptr); - } - - // Step 2: delta = (v - M @ k) * beta - ggml_cann_pool_alloc kv_alloc(ctx.pool(), S_v * nb_f32); - acl_tensor_ptr acl_kv = ggml_cann_create_tensor( - kv_alloc.get(), ACL_FLOAT, nb_f32, ne_vec, nb_vec, 1); - GGML_CANN_CALL_ACLNN_OP(ctx, Mv, acl_s_mat.get(), acl_k.get(), acl_kv.get(), 1); - - ggml_cann_pool_alloc delta_alloc(ctx.pool(), S_v * nb_f32); - acl_tensor_ptr acl_delta = ggml_cann_create_tensor( - delta_alloc.get(), ACL_FLOAT, nb_f32, ne_vec, nb_vec, 1); - aclnn_sub(ctx, acl_v.get(), acl_kv.get(), acl_delta.get()); - aclnn_mul(ctx, acl_delta.get(), acl_beta.get(), nullptr); - - // Step 3: State update M += delta ⊗ k (outer product) - ggml_cann_pool_alloc outer_alloc(ctx.pool(), S_v * S_v * nb_f32); - acl_tensor_ptr acl_outer = ggml_cann_create_tensor( - outer_alloc.get(), ACL_FLOAT, nb_f32, ne_s, nb_s, 2); - GGML_CANN_CALL_ACLNN_OP(ctx, Ger, acl_delta.get(), acl_k.get(), acl_outer.get()); - aclnn_add(ctx, acl_s_mat.get(), acl_outer.get(), nullptr); - - // Step 4: Output attn = M @ q * scale - const size_t attn_off = ((s * n_tokens * H + t * H + h) * S_v) * nb_f32; - float *attn_ptr = (float *)((char *)dst->data + attn_off); - acl_tensor_ptr acl_attn_out = ggml_cann_create_tensor(attn_ptr, ACL_FLOAT, nb_f32, ne_vec, nb_vec, 1, ACL_FORMAT_ND); - GGML_CANN_CALL_ACLNN_OP(ctx, Mv, acl_s_mat.get(), acl_q.get(), acl_attn_out.get(), 1); - aclnn_muls(ctx, acl_attn_out.get(), scale, nullptr, true); - } - } - } -} - -static void ggml_cann_gated_delta_net_math(ggml_backend_cann_context & ctx, ggml_tensor * dst) { - // semantic path using lower-level ACL math operators - // (already ensured conditions: kda=0, t<=8, H,S_v<=256, contiguous) - ggml_tensor * src_q = dst->src[0]; - ggml_tensor * src_k = dst->src[1]; - ggml_tensor * src_v = dst->src[2]; - ggml_tensor * src_g = dst->src[3]; - ggml_tensor * src_beta = dst->src[4]; - ggml_tensor * src_state = dst->src[5]; - - const int64_t S_v = src_v->ne[0]; - const int64_t H = src_v->ne[1]; - const int64_t n_tokens = src_v->ne[2]; - const int64_t n_seqs = src_v->ne[3]; - const float scale = 1.0f / sqrtf((float) S_v); - - const size_t nb_f32 = sizeof(float); - const int64_t state_elems = S_v * S_v * H * n_seqs; - const size_t state_out_offset = (S_v * H * n_tokens * n_seqs) * nb_f32; - - // copy state block before update - memcpy((char *)dst->data + state_out_offset, src_state->data, state_elems * nb_f32); - - int64_t ne_mat[2] = { S_v, S_v }; - size_t nb_mat[2] = { nb_f32, nb_f32 * S_v }; - int64_t ne_vec_col[2] = { S_v, 1 }; - size_t nb_vec_col[2] = { nb_f32, nb_f32 * S_v }; - int64_t ne_vec_row[2] = { 1, S_v }; - size_t nb_vec_row[2] = { nb_f32 * S_v, nb_f32 }; - - for (int64_t s = 0; s < n_seqs; s++) { - for (int64_t h = 0; h < H; h++) { - float *state_ptr = (float *)((char *)dst->data + state_out_offset + ((s * H + h) * S_v * S_v) * nb_f32); - acl_tensor_ptr acl_state = ggml_cann_create_tensor(state_ptr, ACL_FLOAT, nb_f32, ne_mat, nb_mat, 2); - - for (int64_t t = 0; t < n_tokens; t++) { - float *q_ptr = (float *)((char *)src_q->data + h * src_q->nb[1] + t * src_q->nb[2] + s * src_q->nb[3]); - float *k_ptr = (float *)((char *)src_k->data + h * src_k->nb[1] + t * src_k->nb[2] + s * src_k->nb[3]); - float *v_ptr = (float *)((char *)src_v->data + h * src_v->nb[1] + t * src_v->nb[2] + s * src_v->nb[3]); - - float beta_val = *(float *)((char *)src_beta->data + h * src_beta->nb[1] + t * src_beta->nb[2] + s * src_beta->nb[3]); - float g_val = *(float *)((char *)src_g->data + h * 
src_g->nb[1] + t * src_g->nb[2] + s * src_g->nb[3]); - - acl_tensor_ptr acl_q = ggml_cann_create_tensor(q_ptr, ACL_FLOAT, nb_f32, ne_vec_col, nb_vec_col, 2); - acl_tensor_ptr acl_k = ggml_cann_create_tensor(k_ptr, ACL_FLOAT, nb_f32, ne_vec_col, nb_vec_col, 2); - acl_tensor_ptr acl_k_t = ggml_cann_create_tensor(k_ptr, ACL_FLOAT, nb_f32, ne_vec_row, nb_vec_row, 2); - acl_tensor_ptr acl_v = ggml_cann_create_tensor(v_ptr, ACL_FLOAT, nb_f32, ne_vec_col, nb_vec_col, 2); - - // state decay: state *= exp(g) - if (g_val != 0.0f) { - aclnn_muls(ctx, acl_state.get(), expf(g_val), nullptr, true); - } - - // m_k = state @ k - ggml_cann_pool_alloc mk_alloc(ctx.pool(), S_v * nb_f32); - acl_tensor_ptr acl_mk = ggml_cann_create_tensor(mk_alloc.get(), ACL_FLOAT, nb_f32, ne_vec_col, nb_vec_col, 2); - GGML_CANN_CALL_ACLNN_OP(ctx, Mv, acl_state.get(), acl_k.get(), acl_mk.get(), 1); - - // delta = (v - m_k) * beta - ggml_cann_pool_alloc delta_alloc(ctx.pool(), S_v * nb_f32); - acl_tensor_ptr acl_delta = ggml_cann_create_tensor(delta_alloc.get(), ACL_FLOAT, nb_f32, ne_vec_col, nb_vec_col, 2); - aclnn_sub(ctx, acl_v.get(), acl_mk.get(), acl_delta.get()); - aclnn_muls(ctx, acl_delta.get(), beta_val, nullptr, true); - - // outer = delta @ k^T - ggml_cann_pool_alloc outer_alloc(ctx.pool(), S_v * S_v * nb_f32); - acl_tensor_ptr acl_outer = ggml_cann_create_tensor(outer_alloc.get(), ACL_FLOAT, nb_f32, ne_mat, nb_mat, 2); - GGML_CANN_CALL_ACLNN_OP(ctx, Mm, acl_delta.get(), acl_k_t.get(), acl_outer.get(), 2); - - // state += outer - aclnn_add(ctx, acl_state.get(), acl_outer.get(), nullptr); - - // attn = scale * state @ q - float *attn_ptr = (float *)((char *)dst->data + (h * dst->nb[1] + t * dst->nb[2] + s * dst->nb[3])); - acl_tensor_ptr acl_attn = ggml_cann_create_tensor(attn_ptr, ACL_FLOAT, nb_f32, ne_vec_col, nb_vec_col, 2); - GGML_CANN_CALL_ACLNN_OP(ctx, Mv, acl_state.get(), acl_q.get(), acl_attn.get(), 1); - aclnn_muls(ctx, acl_attn.get(), scale, nullptr, true); - } - } - } -} - -// ggml_cann_gated_delta_net_batched +// ggml_cann_gated_delta_net // // Head-parallel implementation of the Gated Delta Net recurrence. // @@ -4446,7 +4218,7 @@ static void ggml_cann_gated_delta_net_math(ggml_backend_cann_context & ctx, ggml // Preconditions (checked by caller): // - no GQA: neq1==H, nek1==H, neq3==n_seqs, nek3==n_seqs // - F32 contiguous q, k, v, g, beta -static void ggml_cann_gated_delta_net_batched(ggml_backend_cann_context & ctx, ggml_tensor * dst) { +void ggml_cann_gated_delta_net(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_tensor * src_q = dst->src[0]; ggml_tensor * src_k = dst->src[1]; ggml_tensor * src_v = dst->src[2]; @@ -4624,177 +4396,3 @@ static void ggml_cann_gated_delta_net_batched(ggml_backend_cann_context & ctx, g } } -void ggml_cann_gated_delta_net(ggml_backend_cann_context & ctx, ggml_tensor * dst) { - ggml_tensor * src_q = dst->src[0]; - ggml_tensor * src_k = dst->src[1]; - ggml_tensor * src_v = dst->src[2]; - ggml_tensor * src_g = dst->src[3]; - ggml_tensor * src_beta = dst->src[4]; - ggml_tensor * src_state = dst->src[5]; - - const int64_t S_v = src_v->ne[0]; - const int64_t H = src_v->ne[1]; - const int64_t n_tokens = src_v->ne[2]; - const int64_t n_seqs = src_v->ne[3]; - - const int64_t neq1 = src_q->ne[1]; - const int64_t nek1 = src_k->ne[1]; - const int64_t neq3 = src_q->ne[3]; - const int64_t nek3 = src_k->ne[3]; - - const bool kda = (src_g->ne[0] == S_v); - - // Batched path: batch over all H heads per timestep using BatchMatMul. 
- // Requires non-GQA (neq1==H, nek1==H) and contiguous F32 inputs. - // Reduces kernel launches by ~H× vs the naive per-head loop. - const bool use_batched = neq1 == H - && nek1 == H - && neq3 == n_seqs - && nek3 == n_seqs - && ggml_is_contiguous(src_q) - && ggml_is_contiguous(src_k) - && ggml_is_contiguous(src_v) - && ggml_is_contiguous(src_g) - && ggml_is_contiguous(src_beta) - && src_q->type == GGML_TYPE_F32; - - if (use_batched) { - ggml_cann_gated_delta_net_batched(ctx, dst); - return; - } - - ggml_cann_gated_delta_net_naive(ctx, dst); - return; - - // ── Dead code: fused aclnnRecurrentGatedDeltaRule path (disabled) ───────── - // Kept for reference; re-enable once the runtime crash is fixed. - // Constraints: no KDA, n_tokens<=8, S_v<=256, H<=256, no GQA, contiguous. - (void)kda; - - const int64_t T = n_seqs * n_tokens; - const float scale = 1.0f / sqrtf((float) S_v); - - const size_t nb_f32 = sizeof(float); - const size_t nb_bf16 = sizeof(uint16_t); - const size_t nb_i32 = sizeof(int32_t); - - // Output layout: [attn_scores | new_states] - const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; - const size_t state_out_offset = attn_score_elems * nb_f32; - - // ---- Cast F32 inputs to BF16 ---- - - // Q: GGML [S_v, neq1, n_tokens, n_seqs] → 3D [S_v, neq1, T] → CANN (T, Nk, Dk) - int64_t ne_q[3] = { S_v, neq1, T }; - size_t nb_q_f32[3] = { nb_f32, S_v * nb_f32, S_v * neq1 * nb_f32 }; - size_t nb_q_bf16[3] = { nb_bf16, S_v * nb_bf16, S_v * neq1 * nb_bf16 }; - - ggml_cann_pool_alloc q_bf16_alloc(ctx.pool(), T * neq1 * S_v * nb_bf16); - acl_tensor_ptr acl_q_f32 = ggml_cann_create_tensor(src_q->data, ACL_FLOAT, nb_f32, ne_q, nb_q_f32, 3); - acl_tensor_ptr acl_q_bf16 = ggml_cann_create_tensor(q_bf16_alloc.get(), ACL_BF16, nb_bf16, ne_q, nb_q_bf16, 3); - aclnn_cast(ctx, acl_q_f32.get(), acl_q_bf16.get(), ACL_BF16); - - // K: GGML [S_v, nek1, n_tokens, n_seqs] → 3D [S_v, nek1, T] → CANN (T, Nk, Dk) - int64_t ne_k[3] = { S_v, nek1, T }; - size_t nb_k_f32[3] = { nb_f32, S_v * nb_f32, S_v * nek1 * nb_f32 }; - size_t nb_k_bf16[3] = { nb_bf16, S_v * nb_bf16, S_v * nek1 * nb_bf16 }; - - ggml_cann_pool_alloc k_bf16_alloc(ctx.pool(), T * nek1 * S_v * nb_bf16); - acl_tensor_ptr acl_k_f32 = ggml_cann_create_tensor(src_k->data, ACL_FLOAT, nb_f32, ne_k, nb_k_f32, 3); - acl_tensor_ptr acl_k_bf16 = ggml_cann_create_tensor(k_bf16_alloc.get(), ACL_BF16, nb_bf16, ne_k, nb_k_bf16, 3); - aclnn_cast(ctx, acl_k_f32.get(), acl_k_bf16.get(), ACL_BF16); - - // V: GGML [S_v, H, n_tokens, n_seqs] → 3D [S_v, H, T] → CANN (T, Nv, Dv) - int64_t ne_v[3] = { S_v, H, T }; - size_t nb_v_f32[3] = { nb_f32, S_v * nb_f32, S_v * H * nb_f32 }; - size_t nb_v_bf16[3] = { nb_bf16, S_v * nb_bf16, S_v * H * nb_bf16 }; - - ggml_cann_pool_alloc v_bf16_alloc(ctx.pool(), T * H * S_v * nb_bf16); - acl_tensor_ptr acl_v_f32 = ggml_cann_create_tensor(src_v->data, ACL_FLOAT, nb_f32, ne_v, nb_v_f32, 3); - acl_tensor_ptr acl_v_bf16 = ggml_cann_create_tensor(v_bf16_alloc.get(), ACL_BF16, nb_bf16, ne_v, nb_v_bf16, 3); - aclnn_cast(ctx, acl_v_f32.get(), acl_v_bf16.get(), ACL_BF16); - - // Beta: GGML [1, H, n_tokens, n_seqs] → 2D [H, T] → CANN (T, Nv) - int64_t ne_hf[2] = { H, T }; - size_t nb_hf_f32[2] = { nb_f32, H * nb_f32 }; - size_t nb_hf_bf16[2] = { nb_bf16, H * nb_bf16 }; - - ggml_cann_pool_alloc beta_bf16_alloc(ctx.pool(), T * H * nb_bf16); - acl_tensor_ptr acl_beta_f32 = ggml_cann_create_tensor(src_beta->data, ACL_FLOAT, nb_f32, ne_hf, nb_hf_f32, 2); - acl_tensor_ptr acl_beta_bf16 = ggml_cann_create_tensor(beta_bf16_alloc.get(), 
ACL_BF16, nb_bf16, ne_hf, nb_hf_bf16, 2); - aclnn_cast(ctx, acl_beta_f32.get(), acl_beta_bf16.get(), ACL_BF16); - - // Gate g: GGML [1, H, n_tokens, n_seqs] → 2D [H, T] → CANN (T, Nv), stays F32 - acl_tensor_ptr acl_g = ggml_cann_create_tensor(src_g->data, ACL_FLOAT, nb_f32, ne_hf, nb_hf_f32, 2); - - // State: GGML [S_v, S_v, H, n_seqs] → CANN (BlockNum, Nv, Dv, Dk) - // GGML stores M = S^T, but the recurrence applied to M has the same form as the - // standard delta rule, so M can be passed directly as the API's state parameter. - const int64_t state_elems = n_seqs * H * S_v * S_v; - int64_t ne_st[4] = { S_v, S_v, H, n_seqs }; - size_t nb_st_f32[4] = { nb_f32, S_v * nb_f32, S_v * S_v * nb_f32, S_v * S_v * H * nb_f32 }; - size_t nb_st_bf16[4] = { nb_bf16, S_v * nb_bf16, S_v * S_v * nb_bf16, S_v * S_v * H * nb_bf16 }; - - ggml_cann_pool_alloc state_bf16_alloc(ctx.pool(), state_elems * nb_bf16); - acl_tensor_ptr acl_state_f32 = ggml_cann_create_tensor(src_state->data, ACL_FLOAT, nb_f32, ne_st, nb_st_f32, 4); - acl_tensor_ptr acl_state_bf16 = ggml_cann_create_tensor(state_bf16_alloc.get(), ACL_BF16, nb_bf16, ne_st, nb_st_bf16, 4); - aclnn_cast(ctx, acl_state_f32.get(), acl_state_bf16.get(), ACL_BF16); - - // Output buffer in BF16: (T, Nv, Dv) — same layout as V - ggml_cann_pool_alloc out_bf16_alloc(ctx.pool(), T * H * S_v * nb_bf16); - acl_tensor_ptr acl_out_bf16 = ggml_cann_create_tensor(out_bf16_alloc.get(), ACL_BF16, nb_bf16, ne_v, nb_v_bf16, 3); - - // ---- Prepare INT32 auxiliary tensors ---- - - // actualSeqLengths: (B,) — each sequence has n_tokens tokens - std::vector host_seq_lens(n_seqs, (int32_t) n_tokens); - ggml_cann_pool_alloc asl_alloc(ctx.pool(), n_seqs * nb_i32); - ACL_CHECK(aclrtMemcpy(asl_alloc.get(), n_seqs * nb_i32, - host_seq_lens.data(), n_seqs * nb_i32, - ACL_MEMCPY_HOST_TO_DEVICE)); - int64_t ne_b[1] = { n_seqs }; - size_t nb_b[1] = { nb_i32 }; - acl_tensor_ptr acl_asl = ggml_cann_create_tensor(asl_alloc.get(), ACL_INT32, nb_i32, ne_b, nb_b, 1); - - // ssmStateIndices: (T,) — token at seq s, pos t maps to state block s - std::vector host_ssm_idx(T); - for (int64_t s = 0; s < n_seqs; s++) { - for (int64_t t = 0; t < n_tokens; t++) { - host_ssm_idx[s * n_tokens + t] = (int32_t) s; - } - } - ggml_cann_pool_alloc ssm_alloc(ctx.pool(), T * nb_i32); - ACL_CHECK(aclrtMemcpy(ssm_alloc.get(), T * nb_i32, - host_ssm_idx.data(), T * nb_i32, - ACL_MEMCPY_HOST_TO_DEVICE)); - int64_t ne_T[1] = { T }; - size_t nb_T[1] = { nb_i32 }; - acl_tensor_ptr acl_ssm = ggml_cann_create_tensor(ssm_alloc.get(), ACL_INT32, nb_i32, ne_T, nb_T, 1); - - // numAcceptedTokens: (B,) — all tokens are accepted - ggml_cann_pool_alloc nat_alloc(ctx.pool(), n_seqs * nb_i32); - ACL_CHECK(aclrtMemcpy(nat_alloc.get(), n_seqs * nb_i32, - host_seq_lens.data(), n_seqs * nb_i32, - ACL_MEMCPY_HOST_TO_DEVICE)); - acl_tensor_ptr acl_nat = ggml_cann_create_tensor(nat_alloc.get(), ACL_INT32, nb_i32, ne_b, nb_b, 1); - - // ---- Call fused operator ---- - GGML_CANN_CALL_ACLNN_OP(ctx, RecurrentGatedDeltaRule, - acl_q_bf16.get(), acl_k_bf16.get(), acl_v_bf16.get(), - acl_beta_bf16.get(), acl_state_bf16.get(), - acl_asl.get(), acl_ssm.get(), - acl_g.get(), nullptr, acl_nat.get(), - scale, acl_out_bf16.get()); - - // ---- Cast BF16 outputs back to F32 ---- - - // Attention output → dst[0 .. 
state_out_offset) - acl_tensor_ptr acl_dst_attn = ggml_cann_create_tensor( - dst->data, ACL_FLOAT, nb_f32, ne_v, nb_v_f32, 3); - aclnn_cast(ctx, acl_out_bf16.get(), acl_dst_attn.get(), ACL_FLOAT); - - // Updated state → dst[state_out_offset .. end) - acl_tensor_ptr acl_dst_state = ggml_cann_create_tensor( - dst->data, ACL_FLOAT, nb_f32, ne_st, nb_st_f32, 4, ACL_FORMAT_ND, state_out_offset); - aclnn_cast(ctx, acl_state_bf16.get(), acl_dst_state.get(), ACL_FLOAT); -} diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index eabd2d4979..2f87c649b9 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2568,8 +2568,29 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten return true; } case GGML_OP_SSM_CONV: - case GGML_OP_GATED_DELTA_NET: return true; + case GGML_OP_GATED_DELTA_NET: + { + // Only the batched path (BatchMatMul over all heads) is efficient. + // Non-contiguous / GQA / non-F32 cases fall back to CPU. + const ggml_tensor * q = op->src[0]; + const ggml_tensor * k = op->src[1]; + const ggml_tensor * v = op->src[2]; + const ggml_tensor * g = op->src[3]; + const ggml_tensor * beta = op->src[4]; + const int64_t H = v->ne[1]; + const int64_t n_seqs = v->ne[3]; + return q->ne[1] == H + && k->ne[1] == H + && q->ne[3] == n_seqs + && k->ne[3] == n_seqs + && ggml_is_contiguous(q) + && ggml_is_contiguous(k) + && ggml_is_contiguous(v) + && ggml_is_contiguous(g) + && ggml_is_contiguous(beta) + && q->type == GGML_TYPE_F32; + } default: return false; }
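
Note for reviewers (not part of the patch): the recurrence that the remaining batched implementation computes per timestep is the same one the removed _naive/_math variants spelled out, i.e. state decay, delta-rule correction, outer-product state update, then the scaled read-out. Below is a minimal host-side C++ sketch of the scalar-gate case for a single (sequence, head) pair; the helper name gated_delta_net_ref and the flat row-major layout are illustrative assumptions, not code from the tree.

// Reference recurrence for one (sequence, head); illustrative only.
// M is the S_v x S_v state in the transposed storage used by the CANN code
// (so M @ k plays the role of S^T @ k); the gate g is the per-head scalar case.
#include <cmath>
#include <cstdint>
#include <vector>

static void gated_delta_net_ref(const float * q,     // [n_tokens * S_v]
                                const float * k,     // [n_tokens * S_v]
                                const float * v,     // [n_tokens * S_v]
                                const float * g,     // [n_tokens], scalar gate per token
                                const float * beta,  // [n_tokens]
                                std::vector<float> & M,    // [S_v * S_v], updated in place
                                std::vector<float> & out,  // [n_tokens * S_v]
                                int64_t S_v, int64_t n_tokens) {
    const float scale = 1.0f / std::sqrt((float) S_v);
    std::vector<float> delta(S_v);
    for (int64_t t = 0; t < n_tokens; t++) {
        const float * qt = q + t * S_v;
        const float * kt = k + t * S_v;
        const float * vt = v + t * S_v;
        const float   ge = std::exp(g[t]);
        // Step 1: state decay, M *= exp(g)
        for (int64_t i = 0; i < S_v * S_v; i++) {
            M[i] *= ge;
        }
        // Step 2: delta = (v - M @ k) * beta
        for (int64_t r = 0; r < S_v; r++) {
            float acc = 0.0f;
            for (int64_t c = 0; c < S_v; c++) {
                acc += M[r * S_v + c] * kt[c];
            }
            delta[r] = (vt[r] - acc) * beta[t];
        }
        // Step 3: state update, M += delta ⊗ k; Step 4: out = scale * (M @ q),
        // using the already-updated row of M, as in the removed naive variant.
        for (int64_t r = 0; r < S_v; r++) {
            float acc = 0.0f;
            for (int64_t c = 0; c < S_v; c++) {
                M[r * S_v + c] += delta[r] * kt[c];
                acc += M[r * S_v + c] * qt[c];
            }
            out[t * S_v + r] = scale * acc;
        }
    }
}

The batched path keeps exactly this per-timestep order of operations but folds the matrix-vector products and the outer-product update for all H heads of a timestep into aclnnBatchMatMul calls, which is where the O(n_seqs * n_tokens) kernel-launch count cited in the commit message comes from.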