This commit is contained in:
ixgbe 2026-02-01 09:20:22 -03:00 committed by GitHub
commit 744a8889f9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 68 additions and 11 deletions

View File

@ -9134,17 +9134,18 @@ static void ggml_compute_forward_ssm_scan_f32(
const float dA = expf(dt_soft_plus * A[h]);
const int g = h / (nh / ng); // repeat_interleave
//uint64_t t_start = ggml_time_us();
// dim
for (int i1 = 0; i1 < nr; ++i1) {
const int ii = i1 + h*nr;
const float x_dt = x[ii] * dt_soft_plus;
float sumf = 0.0f;
#if defined(GGML_SIMD)
#if defined(__ARM_FEATURE_SVE)
int np = 0;
#if defined(GGML_SIMD) && defined(__ARM_FEATURE_SVE)
const int ggml_f32_epr = svcntw();
const int ggml_f32_step = 1 * ggml_f32_epr;
const int np = (nc & ~(ggml_f32_step - 1));
np = (nc & ~(ggml_f32_step - 1));
GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
@ -9170,11 +9171,68 @@ static void ggml_compute_forward_ssm_scan_f32(
}
sumf = GGML_F32xt_REDUCE_ONE(sum);
#elif defined(__riscv_v_intrinsic)
// todo: RVV implementation
const int np = 0;
#else
const int np = (nc & ~(GGML_F32_STEP - 1));
#elif defined(__riscv_v_intrinsic)
const int epr = __riscv_vsetvlmax_e32m2();
const int step = epr * 2;
np = (nc & ~(step - 1));
vfloat32m2_t vdA_vec = __riscv_vfmv_v_f_f32m2(dA, epr);
vfloat32m2_t vx_dt_vec = __riscv_vfmv_v_f_f32m2(x_dt, epr);
vfloat32m2_t vsum0 = __riscv_vfmv_v_f_f32m2(0.0f, epr);
vfloat32m2_t vsum1 = __riscv_vfmv_v_f_f32m2(0.0f, epr);
// unroll by 2
for (int i = 0; i < np; i += step) {
vfloat32m2_t vs0_0 = __riscv_vle32_v_f32m2(s0 + i + ii*nc, epr);
vfloat32m2_t vB_0 = __riscv_vle32_v_f32m2(B + i + g*nc, epr);
vfloat32m2_t vC_0 = __riscv_vle32_v_f32m2(C + i + g*nc, epr);
vfloat32m2_t vstate_0 = __riscv_vfmul_vv_f32m2(vs0_0, vdA_vec, epr);
vstate_0 = __riscv_vfmacc_vv_f32m2(vstate_0, vx_dt_vec, vB_0, epr);
vsum0 = __riscv_vfmacc_vv_f32m2(vsum0, vstate_0, vC_0, epr);
__riscv_vse32_v_f32m2(s + i + ii*nc, vstate_0, epr);
vfloat32m2_t vs0_1 = __riscv_vle32_v_f32m2(s0 + i + epr + ii*nc, epr);
vfloat32m2_t vB_1 = __riscv_vle32_v_f32m2(B + i + epr + g*nc, epr);
vfloat32m2_t vC_1 = __riscv_vle32_v_f32m2(C + i + epr + g*nc, epr);
vfloat32m2_t vstate_1 = __riscv_vfmul_vv_f32m2(vs0_1, vdA_vec, epr);
vstate_1 = __riscv_vfmacc_vv_f32m2(vstate_1, vx_dt_vec, vB_1, epr);
vsum1 = __riscv_vfmacc_vv_f32m2(vsum1, vstate_1, vC_1, epr);
__riscv_vse32_v_f32m2(s + i + epr + ii*nc, vstate_1, epr);
}
// leftovers
int vl;
for (int i = np; i < nc; i += vl) {
vl = __riscv_vsetvl_e32m2(nc - i);
vfloat32m2_t vs0 = __riscv_vle32_v_f32m2(s0 + i + ii*nc, vl);
vfloat32m2_t vB = __riscv_vle32_v_f32m2(B + i + g*nc, vl);
vfloat32m2_t vC = __riscv_vle32_v_f32m2(C + i + g*nc, vl);
vfloat32m2_t vstate = __riscv_vfmul_vv_f32m2(vs0, vdA_vec, vl);
vstate = __riscv_vfmacc_vv_f32m2(vstate, vx_dt_vec, vB, vl);
vsum0 = __riscv_vfmacc_vv_f32m2(vsum0, vstate, vC, vl);
__riscv_vse32_v_f32m2(s + i + ii*nc, vstate, vl);
}
vsum0 = __riscv_vfadd_vv_f32m2(vsum0, vsum1, epr);
vfloat32m1_t vzero = __riscv_vfmv_v_f_f32m1(0.0f, 1);
vfloat32m1_t vsum_reduced = __riscv_vfredusum_vs_f32m2_f32m1(vsum0, vzero, epr);
sumf = __riscv_vfmv_f_s_f32m1_f32(vsum_reduced);
np = nc;
#elif defined(GGML_SIMD)
np = (nc & ~(GGML_F32_STEP - 1));
GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
@ -9204,9 +9262,6 @@ static void ggml_compute_forward_ssm_scan_f32(
// reduce sum0..sum3 to sum0
GGML_F32_VEC_REDUCE(sumf, sum);
#endif
#else
const int np = 0;
#endif
// d_state
for (int i0 = np; i0 < nc; ++i0) {
@ -9220,6 +9275,8 @@ static void ggml_compute_forward_ssm_scan_f32(
}
y[ii] = sumf;
}
//uint64_t t_end = ggml_time_us();
//printf("Mamba2 dim runtime %.3f ms\n\n", (t_end - t_start) / 1000.0f);
}
} else {
// Mamba-1 has an element-wise decay factor for the states