CANN: simplify GATED_DELTA_NET implementation

- Remove dead code: the _math and _naive variants are no longer needed.
- Rename _batched to the public entry point ggml_cann_gated_delta_net.
- In supports_op, return false for non-contiguous / GQA / non-F32 cases so the framework falls back to the CPU instead of running the slow naive path.
- The single remaining implementation uses aclnnBatchMatMul over all H heads per timestep, reducing kernel launches to O(n_seqs * n_tokens).
This commit is contained in:
parent
3707b58628
commit
11e78d8499
|
|
@ -4189,235 +4189,7 @@ void ggml_cann_gated_linear_attn(ggml_backend_cann_context & ctx, ggml_tensor *
|
|||
}
|
||||
}
|
||||
|
||||
// Per-head reference implementation of the Gated Delta Net recurrence.
//
// dst->src[0..5] are q, k, v, g, beta and the incoming state; every buffer is
// accessed as F32. For each sequence s, head h and token t (in that order):
//
//   1. decay  : M    *= exp(g)          (one factor per column in KDA mode,
//                                        a single scalar otherwise)
//   2. delta  : d     = (v - Mv(M, k)) * beta
//   3. update : M    += Ger(d, k)
//   4. output : attn  = Mv(M, q) * scale, with scale = 1 / sqrt(S_v)
//
// The result buffer is laid out as [attn_scores | new_states]: the incoming
// state is first copied into the state area and then updated in place.
//
// q and k may have fewer heads / sequences than v: heads are mapped with
// `h % ne[1]` and sequences with `s / (n_seqs / ne[3])`.
//
// One group of kernels is launched per (sequence, head, token), so this path
// is slow; it exists for the layouts the batched implementation cannot handle.
static void ggml_cann_gated_delta_net_naive(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
    ggml_tensor * src_q     = dst->src[0];
    ggml_tensor * src_k     = dst->src[1];
    ggml_tensor * src_v     = dst->src[2];
    ggml_tensor * src_g     = dst->src[3];
    ggml_tensor * src_beta  = dst->src[4];
    ggml_tensor * src_state = dst->src[5];

    const int64_t S_v      = src_v->ne[0];
    const int64_t H        = src_v->ne[1];
    const int64_t n_tokens = src_v->ne[2];
    const int64_t n_seqs   = src_v->ne[3];

    const int64_t neq1 = src_q->ne[1];
    const int64_t nek1 = src_k->ne[1];
    const int64_t neq3 = src_q->ne[3];
    const int64_t nek3 = src_k->ne[3];

    // KDA mode: g carries one gate value per state column instead of a scalar.
    const bool kda = (src_g->ne[0] == S_v);

    const float  scale  = 1.0f / sqrtf((float) S_v);
    const size_t nb_f32 = sizeof(float);

    // Q/K strides (may differ from V for GQA)
    const size_t nbq1 = src_q->nb[1],    nbq2 = src_q->nb[2],    nbq3 = src_q->nb[3];
    const size_t nbk1 = src_k->nb[1],    nbk2 = src_k->nb[2],    nbk3 = src_k->nb[3];
    const size_t nbv1 = src_v->nb[1],    nbv2 = src_v->nb[2],    nbv3 = src_v->nb[3];
    const size_t nbg1 = src_g->nb[1],    nbg2 = src_g->nb[2],    nbg3 = src_g->nb[3];
    const size_t nbb1 = src_beta->nb[1], nbb2 = src_beta->nb[2], nbb3 = src_beta->nb[3];

    // Number of v-sequences sharing one q / k sequence. The extra upper-bound
    // check keeps the ratio >= 1 so that `s / rq3` below can never divide by zero.
    const int64_t rq3 = (neq3 > 0 && neq3 <= n_seqs) ? n_seqs / neq3 : 1;
    const int64_t rk3 = (nek3 > 0 && nek3 <= n_seqs) ? n_seqs / nek3 : 1;

    // Output layout: [attn_scores | new_states]
    const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
    const size_t  state_out_offset = attn_score_elems * nb_f32;

    // Shapes for per-head operations
    int64_t ne_s[2]    = { S_v, S_v };
    size_t  nb_s[2]    = { nb_f32, S_v * nb_f32 };
    int64_t ne_vec[1]  = { S_v };
    int64_t ne_sc[1]   = { 1 };
    size_t  nb_elem[1] = { nb_f32 };               // stride shared by all 1-D views
    int64_t ne_g_bc[2] = { S_v, 1 };               // for KDA gate broadcast
    size_t  nb_g_bc[2] = { nb_f32, S_v * nb_f32 };
    int64_t ne_g[1]    = { kda ? S_v : 1 };        // gate length for the current mode

    // Copy input state to output state area
    {
        int64_t ne_flat[1] = { S_v * S_v * H * n_seqs };
        size_t  nb_flat[1] = { nb_f32 };
        acl_tensor_ptr acl_sin = ggml_cann_create_tensor(
            src_state->data, ACL_FLOAT, nb_f32, ne_flat, nb_flat, 1);
        acl_tensor_ptr acl_sout = ggml_cann_create_tensor(
            dst->data, ACL_FLOAT, nb_f32, ne_flat, nb_flat, 1, ACL_FORMAT_ND, state_out_offset);
        cann_copy(ctx, acl_sin.get(), acl_sout.get());
    }

    // Scratch buffers and their views do not depend on (s, h, t): allocate them
    // once instead of once per token.
    ggml_cann_pool_alloc g_exp_alloc(ctx.pool(), ne_g[0] * nb_f32);
    ggml_cann_pool_alloc kv_alloc(ctx.pool(), S_v * nb_f32);
    ggml_cann_pool_alloc delta_alloc(ctx.pool(), S_v * nb_f32);
    ggml_cann_pool_alloc outer_alloc(ctx.pool(), S_v * S_v * nb_f32);

    acl_tensor_ptr acl_g_exp = ggml_cann_create_tensor(
        g_exp_alloc.get(), ACL_FLOAT, nb_f32, ne_g, nb_elem, 1);
    // View used as the multiplier in the decay step.
    // KDA: GGML [S_v,1] -> CANN [1,S_v] broadcasts along rows; scalar: 1 element.
    acl_tensor_ptr acl_g_factor = kda
        ? ggml_cann_create_tensor(g_exp_alloc.get(), ACL_FLOAT, nb_f32, ne_g_bc, nb_g_bc, 2)
        : ggml_cann_create_tensor(g_exp_alloc.get(), ACL_FLOAT, nb_f32, ne_sc, nb_elem, 1);
    acl_tensor_ptr acl_kv = ggml_cann_create_tensor(
        kv_alloc.get(), ACL_FLOAT, nb_f32, ne_vec, nb_elem, 1);
    acl_tensor_ptr acl_delta = ggml_cann_create_tensor(
        delta_alloc.get(), ACL_FLOAT, nb_f32, ne_vec, nb_elem, 1);
    acl_tensor_ptr acl_outer = ggml_cann_create_tensor(
        outer_alloc.get(), ACL_FLOAT, nb_f32, ne_s, nb_s, 2);

    for (int64_t s = 0; s < n_seqs; s++) {
        const int64_t iq3 = s / rq3;
        const int64_t ik3 = s / rk3;

        for (int64_t h = 0; h < H; h++) {
            const size_t s_off = state_out_offset + ((s * H + h) * S_v * S_v) * nb_f32;

            const int64_t iq1 = h % neq1;
            const int64_t ik1 = h % nek1;

            // State matrix view [S_v, S_v] (transposed storage: M[j][i] = S[i][j])
            // Mv(M, k) = S^T @ k,  Mv(M, q) = S^T @ q
            // The view is the same for every token of this (s, h) pair.
            acl_tensor_ptr acl_s_mat = ggml_cann_create_tensor(
                dst->data, ACL_FLOAT, nb_f32, ne_s, nb_s, 2, ACL_FORMAT_ND, s_off);

            for (int64_t t = 0; t < n_tokens; t++) {
                // Input tensor views
                const size_t q_off    = iq3 * nbq3 + t * nbq2 + iq1 * nbq1;
                const size_t k_off    = ik3 * nbk3 + t * nbk2 + ik1 * nbk1;
                const size_t v_off    = s * nbv3 + t * nbv2 + h * nbv1;
                const size_t beta_off = s * nbb3 + t * nbb2 + h * nbb1;
                const size_t g_off    = s * nbg3 + t * nbg2 + h * nbg1;

                acl_tensor_ptr acl_q = ggml_cann_create_tensor(
                    src_q->data, ACL_FLOAT, nb_f32, ne_vec, nb_elem, 1, ACL_FORMAT_ND, q_off);
                acl_tensor_ptr acl_k = ggml_cann_create_tensor(
                    src_k->data, ACL_FLOAT, nb_f32, ne_vec, nb_elem, 1, ACL_FORMAT_ND, k_off);
                acl_tensor_ptr acl_v = ggml_cann_create_tensor(
                    src_v->data, ACL_FLOAT, nb_f32, ne_vec, nb_elem, 1, ACL_FORMAT_ND, v_off);
                acl_tensor_ptr acl_beta = ggml_cann_create_tensor(
                    src_beta->data, ACL_FLOAT, nb_f32, ne_sc, nb_elem, 1, ACL_FORMAT_ND, beta_off);
                acl_tensor_ptr acl_g_src = ggml_cann_create_tensor(
                    src_g->data, ACL_FLOAT, nb_f32, ne_g, nb_elem, 1, ACL_FORMAT_ND, g_off);

                // Step 1: State decay S *= exp(g)
                //   KDA mode   : M[j][i] *= exp(g[i]) for all j
                //   scalar mode: M       *= exp(g[0])
                // g is copied to scratch first so the input is not overwritten by exp.
                cann_copy(ctx, acl_g_src.get(), acl_g_exp.get());
                aclnn_exp(ctx, acl_g_exp.get());
                aclnn_mul(ctx, acl_s_mat.get(), acl_g_factor.get(), nullptr);

                // Step 2: delta = (v - M @ k) * beta
                GGML_CANN_CALL_ACLNN_OP(ctx, Mv, acl_s_mat.get(), acl_k.get(), acl_kv.get(), 1);
                aclnn_sub(ctx, acl_v.get(), acl_kv.get(), acl_delta.get());
                aclnn_mul(ctx, acl_delta.get(), acl_beta.get(), nullptr);

                // Step 3: State update M += delta ⊗ k (outer product)
                GGML_CANN_CALL_ACLNN_OP(ctx, Ger, acl_delta.get(), acl_k.get(), acl_outer.get());
                aclnn_add(ctx, acl_s_mat.get(), acl_outer.get(), nullptr);

                // Step 4: Output attn = M @ q * scale
                const size_t attn_off = ((s * n_tokens * H + t * H + h) * S_v) * nb_f32;
                float * attn_ptr = (float *) ((char *) dst->data + attn_off);
                acl_tensor_ptr acl_attn_out = ggml_cann_create_tensor(
                    attn_ptr, ACL_FLOAT, nb_f32, ne_vec, nb_elem, 1, ACL_FORMAT_ND);
                GGML_CANN_CALL_ACLNN_OP(ctx, Mv, acl_s_mat.get(), acl_q.get(), acl_attn_out.get(), 1);
                aclnn_muls(ctx, acl_attn_out.get(), scale, nullptr, true);
            }
        }
    }
}
|
||||
|
||||
// Gated Delta Net recurrence built from elementary aclnn operators, for the
// restricted case of a scalar gate per (head, token) and q/k/v that all share
// the same head and sequence counts (q, k, v, g and beta are indexed with the
// same (h, t, s) triple through their own strides).
//
// dst->src[0..5] are q, k, v, g, beta and the incoming state; every buffer is
// accessed as F32. Per sequence s, head h and token t:
//
//   state *= exp(g)
//   delta  = (v - Mv(state, k)) * beta
//   state += Ger(delta, k)
//   attn   = Mv(state, q) * scale,   scale = 1 / sqrt(S_v)
//
// The state is copied into the second part of dst and updated in place.
//
// All arithmetic, including the state copy and the reads of g and beta, is
// issued as device operations: tensor data is handled as device memory by the
// other paths in this file (cann_copy / aclrtMemcpy HOST_TO_DEVICE), so it must
// not be touched with host memcpy or dereferenced on the host.
static void ggml_cann_gated_delta_net_math(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
    // semantic path using lower-level ACL math operators
    // (already ensured conditions: kda=0, t<=8, H,S_v<=256, contiguous)
    ggml_tensor * src_q     = dst->src[0];
    ggml_tensor * src_k     = dst->src[1];
    ggml_tensor * src_v     = dst->src[2];
    ggml_tensor * src_g     = dst->src[3];
    ggml_tensor * src_beta  = dst->src[4];
    ggml_tensor * src_state = dst->src[5];

    const int64_t S_v      = src_v->ne[0];
    const int64_t H        = src_v->ne[1];
    const int64_t n_tokens = src_v->ne[2];
    const int64_t n_seqs   = src_v->ne[3];
    const float   scale    = 1.0f / sqrtf((float) S_v);

    const size_t  nb_f32           = sizeof(float);
    const int64_t state_elems      = S_v * S_v * H * n_seqs;
    const size_t  state_out_offset = (S_v * H * n_tokens * n_seqs) * nb_f32;

    int64_t ne_mat[2]  = { S_v, S_v };
    size_t  nb_mat[2]  = { nb_f32, nb_f32 * S_v };
    int64_t ne_vec[1]  = { S_v };
    int64_t ne_sc[1]   = { 1 };
    size_t  nb_elem[1] = { nb_f32 };   // stride shared by all 1-D views

    // copy state block before update (device-side copy)
    {
        int64_t ne_flat[1] = { state_elems };
        size_t  nb_flat[1] = { nb_f32 };
        acl_tensor_ptr acl_sin = ggml_cann_create_tensor(
            src_state->data, ACL_FLOAT, nb_f32, ne_flat, nb_flat, 1);
        acl_tensor_ptr acl_sout = ggml_cann_create_tensor(
            dst->data, ACL_FLOAT, nb_f32, ne_flat, nb_flat, 1, ACL_FORMAT_ND, state_out_offset);
        cann_copy(ctx, acl_sin.get(), acl_sout.get());
    }

    // Scratch buffers are independent of (s, h, t); allocate them once.
    ggml_cann_pool_alloc gate_alloc(ctx.pool(), nb_f32);
    ggml_cann_pool_alloc mk_alloc(ctx.pool(), S_v * nb_f32);
    ggml_cann_pool_alloc delta_alloc(ctx.pool(), S_v * nb_f32);
    ggml_cann_pool_alloc outer_alloc(ctx.pool(), S_v * S_v * nb_f32);

    acl_tensor_ptr acl_gate  = ggml_cann_create_tensor(gate_alloc.get(),  ACL_FLOAT, nb_f32, ne_sc,  nb_elem, 1);
    acl_tensor_ptr acl_mk    = ggml_cann_create_tensor(mk_alloc.get(),    ACL_FLOAT, nb_f32, ne_vec, nb_elem, 1);
    acl_tensor_ptr acl_delta = ggml_cann_create_tensor(delta_alloc.get(), ACL_FLOAT, nb_f32, ne_vec, nb_elem, 1);
    acl_tensor_ptr acl_outer = ggml_cann_create_tensor(outer_alloc.get(), ACL_FLOAT, nb_f32, ne_mat, nb_mat, 2);

    for (int64_t s = 0; s < n_seqs; s++) {
        for (int64_t h = 0; h < H; h++) {
            const size_t state_off = state_out_offset + ((s * H + h) * S_v * S_v) * nb_f32;
            acl_tensor_ptr acl_state = ggml_cann_create_tensor(
                dst->data, ACL_FLOAT, nb_f32, ne_mat, nb_mat, 2, ACL_FORMAT_ND, state_off);

            for (int64_t t = 0; t < n_tokens; t++) {
                const size_t q_off    = h * src_q->nb[1]    + t * src_q->nb[2]    + s * src_q->nb[3];
                const size_t k_off    = h * src_k->nb[1]    + t * src_k->nb[2]    + s * src_k->nb[3];
                const size_t v_off    = h * src_v->nb[1]    + t * src_v->nb[2]    + s * src_v->nb[3];
                const size_t beta_off = h * src_beta->nb[1] + t * src_beta->nb[2] + s * src_beta->nb[3];
                const size_t g_off    = h * src_g->nb[1]    + t * src_g->nb[2]    + s * src_g->nb[3];

                // 1-D vector views: Mv and Ger take plain vectors.
                acl_tensor_ptr acl_q = ggml_cann_create_tensor(
                    src_q->data, ACL_FLOAT, nb_f32, ne_vec, nb_elem, 1, ACL_FORMAT_ND, q_off);
                acl_tensor_ptr acl_k = ggml_cann_create_tensor(
                    src_k->data, ACL_FLOAT, nb_f32, ne_vec, nb_elem, 1, ACL_FORMAT_ND, k_off);
                acl_tensor_ptr acl_v = ggml_cann_create_tensor(
                    src_v->data, ACL_FLOAT, nb_f32, ne_vec, nb_elem, 1, ACL_FORMAT_ND, v_off);
                acl_tensor_ptr acl_beta = ggml_cann_create_tensor(
                    src_beta->data, ACL_FLOAT, nb_f32, ne_sc, nb_elem, 1, ACL_FORMAT_ND, beta_off);
                acl_tensor_ptr acl_g_src = ggml_cann_create_tensor(
                    src_g->data, ACL_FLOAT, nb_f32, ne_sc, nb_elem, 1, ACL_FORMAT_ND, g_off);

                // state decay: state *= exp(g)
                // g is exponentiated in scratch so the input tensor stays intact.
                // g == 0 needs no special case: exp(0) == 1 leaves the state unchanged.
                cann_copy(ctx, acl_g_src.get(), acl_gate.get());
                aclnn_exp(ctx, acl_gate.get());
                aclnn_mul(ctx, acl_state.get(), acl_gate.get(), nullptr);

                // m_k = state @ k
                GGML_CANN_CALL_ACLNN_OP(ctx, Mv, acl_state.get(), acl_k.get(), acl_mk.get(), 1);

                // delta = (v - m_k) * beta
                aclnn_sub(ctx, acl_v.get(), acl_mk.get(), acl_delta.get());
                aclnn_mul(ctx, acl_delta.get(), acl_beta.get(), nullptr);

                // outer = delta ⊗ k (outer product)
                GGML_CANN_CALL_ACLNN_OP(ctx, Ger, acl_delta.get(), acl_k.get(), acl_outer.get());

                // state += outer
                aclnn_add(ctx, acl_state.get(), acl_outer.get(), nullptr);

                // attn = scale * state @ q
                // The attention block is addressed as a packed [S_v, H, n_tokens, n_seqs]
                // array at the start of dst, matching state_out_offset above.
                // NOTE(review): dst's own ne/nb are not visible here - confirm this layout.
                const size_t attn_off = ((s * n_tokens * H + t * H + h) * S_v) * nb_f32;
                acl_tensor_ptr acl_attn = ggml_cann_create_tensor(
                    dst->data, ACL_FLOAT, nb_f32, ne_vec, nb_elem, 1, ACL_FORMAT_ND, attn_off);
                GGML_CANN_CALL_ACLNN_OP(ctx, Mv, acl_state.get(), acl_q.get(), acl_attn.get(), 1);
                aclnn_muls(ctx, acl_attn.get(), scale, nullptr, true);
            }
        }
    }
}
|
||||
|
||||
// ggml_cann_gated_delta_net_batched
|
||||
// ggml_cann_gated_delta_net
|
||||
//
|
||||
// Head-parallel implementation of the Gated Delta Net recurrence.
|
||||
//
|
||||
|
|
@ -4446,7 +4218,7 @@ static void ggml_cann_gated_delta_net_math(ggml_backend_cann_context & ctx, ggml
|
|||
// Preconditions (checked by caller):
|
||||
// - no GQA: neq1==H, nek1==H, neq3==n_seqs, nek3==n_seqs
|
||||
// - F32 contiguous q, k, v, g, beta
|
||||
static void ggml_cann_gated_delta_net_batched(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
void ggml_cann_gated_delta_net(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * src_q = dst->src[0];
|
||||
ggml_tensor * src_k = dst->src[1];
|
||||
ggml_tensor * src_v = dst->src[2];
|
||||
|
|
@ -4624,177 +4396,3 @@ static void ggml_cann_gated_delta_net_batched(ggml_backend_cann_context & ctx, g
|
|||
}
|
||||
}
|
||||
|
||||
// Public entry point for GGML_OP_GATED_DELTA_NET on the CANN backend.
//
// dst->src[0..4] are q, k, v, g and beta (src[5], the state, is consumed by the
// implementations). Picks one of two implementations:
//
//   * batched - batches all H heads per timestep with BatchMatMul. Requires
//               q and k to have the same head and sequence counts as v (no
//               GQA), contiguous q / k / v / g / beta, and F32 q.
//   * naive   - per-head loop, used for every other layout.
//
// The previously present fused aclnnRecurrentGatedDeltaRule path was placed
// after an unconditional return and therefore never executed; it has been
// dropped from this function (it remains available in version control).
void ggml_cann_gated_delta_net(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src_q    = dst->src[0];
    const ggml_tensor * src_k    = dst->src[1];
    const ggml_tensor * src_v    = dst->src[2];
    const ggml_tensor * src_g    = dst->src[3];
    const ggml_tensor * src_beta = dst->src[4];

    const int64_t H      = src_v->ne[1];
    const int64_t n_seqs = src_v->ne[3];

    // Batched path: batch over all H heads per timestep using BatchMatMul.
    // Requires non-GQA (neq1==H, nek1==H) and contiguous F32 inputs.
    // Reduces kernel launches by ~H× vs the naive per-head loop.
    const bool same_heads = src_q->ne[1] == H      && src_k->ne[1] == H;
    const bool same_seqs  = src_q->ne[3] == n_seqs && src_k->ne[3] == n_seqs;
    const bool contiguous = ggml_is_contiguous(src_q)
                         && ggml_is_contiguous(src_k)
                         && ggml_is_contiguous(src_v)
                         && ggml_is_contiguous(src_g)
                         && ggml_is_contiguous(src_beta);

    const bool use_batched = same_heads && same_seqs && contiguous && src_q->type == GGML_TYPE_F32;

    if (use_batched) {
        ggml_cann_gated_delta_net_batched(ctx, dst);
        return;
    }

    ggml_cann_gated_delta_net_naive(ctx, dst);
}
|
||||
|
|
|
|||
|
|
@ -2568,8 +2568,29 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
|
|||
return true;
|
||||
}
|
||||
case GGML_OP_SSM_CONV:
|
||||
case GGML_OP_GATED_DELTA_NET:
|
||||
return true;
|
||||
case GGML_OP_GATED_DELTA_NET:
|
||||
{
|
||||
// Only the batched path (BatchMatMul over all heads) is efficient.
|
||||
// Non-contiguous / GQA / non-F32 cases fall back to CPU.
|
||||
const ggml_tensor * q = op->src[0];
|
||||
const ggml_tensor * k = op->src[1];
|
||||
const ggml_tensor * v = op->src[2];
|
||||
const ggml_tensor * g = op->src[3];
|
||||
const ggml_tensor * beta = op->src[4];
|
||||
const int64_t H = v->ne[1];
|
||||
const int64_t n_seqs = v->ne[3];
|
||||
return q->ne[1] == H
|
||||
&& k->ne[1] == H
|
||||
&& q->ne[3] == n_seqs
|
||||
&& k->ne[3] == n_seqs
|
||||
&& ggml_is_contiguous(q)
|
||||
&& ggml_is_contiguous(k)
|
||||
&& ggml_is_contiguous(v)
|
||||
&& ggml_is_contiguous(g)
|
||||
&& ggml_is_contiguous(beta)
|
||||
&& q->type == GGML_TYPE_F32;
|
||||
}
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue