diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 48c8964361..489bd6aa1c 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -9134,17 +9134,18 @@ static void ggml_compute_forward_ssm_scan_f32(
             const float dA = expf(dt_soft_plus * A[h]);
             const int g = h / (nh / ng); // repeat_interleave
 
+            //uint64_t t_start = ggml_time_us();
             // dim
             for (int i1 = 0; i1 < nr; ++i1) {
                 const int ii = i1 + h*nr;
                 const float x_dt = x[ii] * dt_soft_plus;
                 float sumf = 0.0f;
-#if defined(GGML_SIMD)
-    #if defined(__ARM_FEATURE_SVE)
+                int np = 0;
+#if defined(GGML_SIMD) && defined(__ARM_FEATURE_SVE)
                 const int ggml_f32_epr = svcntw();
                 const int ggml_f32_step = 1 * ggml_f32_epr;
 
-                const int np = (nc & ~(ggml_f32_step - 1));
+                np = (nc & ~(ggml_f32_step - 1));
 
                 GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
 
@@ -9170,11 +9171,68 @@ static void ggml_compute_forward_ssm_scan_f32(
                 }
 
                 sumf = GGML_F32xt_REDUCE_ONE(sum);
-    #elif defined(__riscv_v_intrinsic)
-                // todo: RVV implementation
-                const int np = 0;
-    #else
-                const int np = (nc & ~(GGML_F32_STEP - 1));
+#elif defined(__riscv_v_intrinsic)
+                const int epr = __riscv_vsetvlmax_e32m2();
+                const int step = epr * 2;
+                np = (nc & ~(step - 1));
+
+                vfloat32m2_t vdA_vec = __riscv_vfmv_v_f_f32m2(dA, epr);
+                vfloat32m2_t vx_dt_vec = __riscv_vfmv_v_f_f32m2(x_dt, epr);
+
+                vfloat32m2_t vsum0 = __riscv_vfmv_v_f_f32m2(0.0f, epr);
+                vfloat32m2_t vsum1 = __riscv_vfmv_v_f_f32m2(0.0f, epr);
+
+                // unroll by 2
+                for (int i = 0; i < np; i += step) {
+                    vfloat32m2_t vs0_0 = __riscv_vle32_v_f32m2(s0 + i + ii*nc, epr);
+                    vfloat32m2_t vB_0 = __riscv_vle32_v_f32m2(B + i + g*nc, epr);
+                    vfloat32m2_t vC_0 = __riscv_vle32_v_f32m2(C + i + g*nc, epr);
+
+                    vfloat32m2_t vstate_0 = __riscv_vfmul_vv_f32m2(vs0_0, vdA_vec, epr);
+                    vstate_0 = __riscv_vfmacc_vv_f32m2(vstate_0, vx_dt_vec, vB_0, epr);
+                    vsum0 = __riscv_vfmacc_vv_f32m2(vsum0, vstate_0, vC_0, epr);
+
+                    __riscv_vse32_v_f32m2(s + i + ii*nc, vstate_0, epr);
+
+                    vfloat32m2_t vs0_1 = __riscv_vle32_v_f32m2(s0 + i + epr + ii*nc, epr);
+                    vfloat32m2_t vB_1 = __riscv_vle32_v_f32m2(B + i + epr + g*nc, epr);
+                    vfloat32m2_t vC_1 = __riscv_vle32_v_f32m2(C + i + epr + g*nc, epr);
+
+                    vfloat32m2_t vstate_1 = __riscv_vfmul_vv_f32m2(vs0_1, vdA_vec, epr);
+                    vstate_1 = __riscv_vfmacc_vv_f32m2(vstate_1, vx_dt_vec, vB_1, epr);
+                    vsum1 = __riscv_vfmacc_vv_f32m2(vsum1, vstate_1, vC_1, epr);
+
+                    __riscv_vse32_v_f32m2(s + i + epr + ii*nc, vstate_1, epr);
+                }
+
+                // leftovers (tail-undisturbed vfmacc: lanes past vl must keep their partial sums for the full-width reduction below)
+                int vl;
+                for (int i = np; i < nc; i += vl) {
+                    vl = __riscv_vsetvl_e32m2(nc - i);
+
+                    vfloat32m2_t vs0 = __riscv_vle32_v_f32m2(s0 + i + ii*nc, vl);
+                    vfloat32m2_t vB = __riscv_vle32_v_f32m2(B + i + g*nc, vl);
+                    vfloat32m2_t vC = __riscv_vle32_v_f32m2(C + i + g*nc, vl);
+
+                    vfloat32m2_t vstate = __riscv_vfmul_vv_f32m2(vs0, vdA_vec, vl);
+                    vstate = __riscv_vfmacc_vv_f32m2(vstate, vx_dt_vec, vB, vl);
+
+                    vsum0 = __riscv_vfmacc_vv_f32m2_tu(vsum0, vstate, vC, vl);
+
+                    __riscv_vse32_v_f32m2(s + i + ii*nc, vstate, vl);
+                }
+
+                vsum0 = __riscv_vfadd_vv_f32m2(vsum0, vsum1, epr);
+
+                vfloat32m1_t vzero = __riscv_vfmv_v_f_f32m1(0.0f, 1);
+                vfloat32m1_t vsum_reduced = __riscv_vfredusum_vs_f32m2_f32m1(vsum0, vzero, epr);
+
+                sumf = __riscv_vfmv_f_s_f32m1_f32(vsum_reduced);
+
+                np = nc;
+
+#elif defined(GGML_SIMD)
+                np = (nc & ~(GGML_F32_STEP - 1));
 
                 GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
 
@@ -9204,9 +9262,6 @@ static void ggml_compute_forward_ssm_scan_f32(
 
                 // reduce sum0..sum3 to sum0
                 GGML_F32_VEC_REDUCE(sumf, sum);
-    #endif
-#else
-                const int np = 0;
 #endif
                 // d_state
                 for (int i0 = np; i0 < nc; ++i0) {
@@ -9220,6 +9275,8 @@ static void ggml_compute_forward_ssm_scan_f32(
                 }
                 y[ii] = sumf;
             }
+            //uint64_t t_end = ggml_time_us();
+            //printf("Mamba2 dim runtime %.3f ms\n\n", (t_end - t_start) / 1000.0f);
         }
     } else {
         // Mamba-1 has an element-wise decay factor for the states
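
For reference, the SVE path and the new RVV path both vectorize the same per-row recurrence that the scalar d_state loop computes for the unaligned tail (and for the full range when no SIMD path is compiled). A minimal scalar sketch of that reference is below; the helper name ssm_scan_row_ref and its flat signature are illustrative only, with B and C assumed pre-offset by g*nc and s0/s by ii*nc.

// Scalar reference for one (h, i1) row of the Mamba-2 scan.
// Returns y[ii]. Hypothetical standalone helper, not part of the patch.
static float ssm_scan_row_ref(const float * s0, float * s,
                              const float * B, const float * C,
                              float dA, float x_dt, int nc) {
    float sumf = 0.0f;
    for (int i0 = 0; i0 < nc; ++i0) {
        const float state = (s0[i0] * dA) + (B[i0] * x_dt); // decay previous state, add input
        s[i0] = state;                                      // write back the updated state
        sumf += state * C[i0];                              // accumulate the output dot product with C
    }
    return sumf; // y[ii]
}

Comparing y[ii] row by row against this reference (e.g. in a QEMU run built with -march=rv64gcv) is one way to sanity-check the vectorized state update, store, and reduction.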