Merge 8a7e47a8b6 into 2634ed207a
This commit is contained in:
commit
744a8889f9
|
|
@ -9134,17 +9134,18 @@ static void ggml_compute_forward_ssm_scan_f32(
|
|||
const float dA = expf(dt_soft_plus * A[h]);
|
||||
const int g = h / (nh / ng); // repeat_interleave
|
||||
|
||||
//uint64_t t_start = ggml_time_us();
|
||||
// dim
|
||||
for (int i1 = 0; i1 < nr; ++i1) {
|
||||
const int ii = i1 + h*nr;
|
||||
const float x_dt = x[ii] * dt_soft_plus;
|
||||
float sumf = 0.0f;
|
||||
#if defined(GGML_SIMD)
|
||||
#if defined(__ARM_FEATURE_SVE)
|
||||
int np = 0;
|
||||
#if defined(GGML_SIMD) && defined(__ARM_FEATURE_SVE)
|
||||
const int ggml_f32_epr = svcntw();
|
||||
const int ggml_f32_step = 1 * ggml_f32_epr;
|
||||
|
||||
const int np = (nc & ~(ggml_f32_step - 1));
|
||||
np = (nc & ~(ggml_f32_step - 1));
|
||||
|
||||
GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
|
||||
|
||||
|
|
@ -9170,11 +9171,68 @@ static void ggml_compute_forward_ssm_scan_f32(
|
|||
}
|
||||
|
||||
sumf = GGML_F32xt_REDUCE_ONE(sum);
|
||||
#elif defined(__riscv_v_intrinsic)
|
||||
// todo: RVV implementation
|
||||
const int np = 0;
|
||||
#else
|
||||
const int np = (nc & ~(GGML_F32_STEP - 1));
|
||||
#elif defined(__riscv_v_intrinsic)
|
||||
const int epr = __riscv_vsetvlmax_e32m2();
|
||||
const int step = epr * 2;
|
||||
np = (nc & ~(step - 1));
|
||||
|
||||
vfloat32m2_t vdA_vec = __riscv_vfmv_v_f_f32m2(dA, epr);
|
||||
vfloat32m2_t vx_dt_vec = __riscv_vfmv_v_f_f32m2(x_dt, epr);
|
||||
|
||||
vfloat32m2_t vsum0 = __riscv_vfmv_v_f_f32m2(0.0f, epr);
|
||||
vfloat32m2_t vsum1 = __riscv_vfmv_v_f_f32m2(0.0f, epr);
|
||||
|
||||
// unroll by 2
|
||||
for (int i = 0; i < np; i += step) {
|
||||
vfloat32m2_t vs0_0 = __riscv_vle32_v_f32m2(s0 + i + ii*nc, epr);
|
||||
vfloat32m2_t vB_0 = __riscv_vle32_v_f32m2(B + i + g*nc, epr);
|
||||
vfloat32m2_t vC_0 = __riscv_vle32_v_f32m2(C + i + g*nc, epr);
|
||||
|
||||
vfloat32m2_t vstate_0 = __riscv_vfmul_vv_f32m2(vs0_0, vdA_vec, epr);
|
||||
vstate_0 = __riscv_vfmacc_vv_f32m2(vstate_0, vx_dt_vec, vB_0, epr);
|
||||
vsum0 = __riscv_vfmacc_vv_f32m2(vsum0, vstate_0, vC_0, epr);
|
||||
|
||||
__riscv_vse32_v_f32m2(s + i + ii*nc, vstate_0, epr);
|
||||
|
||||
vfloat32m2_t vs0_1 = __riscv_vle32_v_f32m2(s0 + i + epr + ii*nc, epr);
|
||||
vfloat32m2_t vB_1 = __riscv_vle32_v_f32m2(B + i + epr + g*nc, epr);
|
||||
vfloat32m2_t vC_1 = __riscv_vle32_v_f32m2(C + i + epr + g*nc, epr);
|
||||
|
||||
vfloat32m2_t vstate_1 = __riscv_vfmul_vv_f32m2(vs0_1, vdA_vec, epr);
|
||||
vstate_1 = __riscv_vfmacc_vv_f32m2(vstate_1, vx_dt_vec, vB_1, epr);
|
||||
vsum1 = __riscv_vfmacc_vv_f32m2(vsum1, vstate_1, vC_1, epr);
|
||||
|
||||
__riscv_vse32_v_f32m2(s + i + epr + ii*nc, vstate_1, epr);
|
||||
}
|
||||
|
||||
// leftovers
|
||||
int vl;
|
||||
for (int i = np; i < nc; i += vl) {
|
||||
vl = __riscv_vsetvl_e32m2(nc - i);
|
||||
|
||||
vfloat32m2_t vs0 = __riscv_vle32_v_f32m2(s0 + i + ii*nc, vl);
|
||||
vfloat32m2_t vB = __riscv_vle32_v_f32m2(B + i + g*nc, vl);
|
||||
vfloat32m2_t vC = __riscv_vle32_v_f32m2(C + i + g*nc, vl);
|
||||
|
||||
vfloat32m2_t vstate = __riscv_vfmul_vv_f32m2(vs0, vdA_vec, vl);
|
||||
vstate = __riscv_vfmacc_vv_f32m2(vstate, vx_dt_vec, vB, vl);
|
||||
|
||||
vsum0 = __riscv_vfmacc_vv_f32m2(vsum0, vstate, vC, vl);
|
||||
|
||||
__riscv_vse32_v_f32m2(s + i + ii*nc, vstate, vl);
|
||||
}
|
||||
|
||||
vsum0 = __riscv_vfadd_vv_f32m2(vsum0, vsum1, epr);
|
||||
|
||||
vfloat32m1_t vzero = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
||||
vfloat32m1_t vsum_reduced = __riscv_vfredusum_vs_f32m2_f32m1(vsum0, vzero, epr);
|
||||
|
||||
sumf = __riscv_vfmv_f_s_f32m1_f32(vsum_reduced);
|
||||
|
||||
np = nc;
|
||||
|
||||
#elif defined(GGML_SIMD)
|
||||
np = (nc & ~(GGML_F32_STEP - 1));
|
||||
|
||||
GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
|
||||
|
||||
|
|
@ -9204,9 +9262,6 @@ static void ggml_compute_forward_ssm_scan_f32(
|
|||
|
||||
// reduce sum0..sum3 to sum0
|
||||
GGML_F32_VEC_REDUCE(sumf, sum);
|
||||
#endif
|
||||
#else
|
||||
const int np = 0;
|
||||
#endif
|
||||
// d_state
|
||||
for (int i0 = np; i0 < nc; ++i0) {
|
||||
|
|
@ -9220,6 +9275,8 @@ static void ggml_compute_forward_ssm_scan_f32(
|
|||
}
|
||||
y[ii] = sumf;
|
||||
}
|
||||
//uint64_t t_end = ggml_time_us();
|
||||
//printf("Mamba2 dim runtime %.3f ms\n\n", (t_end - t_start) / 1000.0f);
|
||||
}
|
||||
} else {
|
||||
// Mamba-1 has an element-wise decay factor for the states
|
||||
|
|
|
|||
Loading…
Reference in New Issue