From 04719ae5171583d797c4342213ef84db6effcd27 Mon Sep 17 00:00:00 2001 From: taimur-10x Date: Wed, 4 Mar 2026 19:00:30 +0500 Subject: [PATCH] ggml-cpu: refactor; add rvv repacking for mxfp4 --- ggml/src/ggml-cpu/arch/riscv/repack.cpp | 260 +++++++-- ggml/src/ggml-cpu/repack.cpp | 715 +++++++++++++++++++++++- ggml/src/ggml-cpu/repack.h | 51 +- 3 files changed, 964 insertions(+), 62 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/riscv/repack.cpp b/ggml/src/ggml-cpu/arch/riscv/repack.cpp index af9ac7326f..fc0ee82bb2 100644 --- a/ggml/src/ggml-cpu/arch/riscv/repack.cpp +++ b/ggml/src/ggml-cpu/arch/riscv/repack.cpp @@ -271,7 +271,7 @@ void ggml_gemv_q4_0_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v } template -void ggml_gemv_q8_0_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +static inline void ggml_gemv_q8_0_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; const int blocklen = 1; @@ -334,12 +334,13 @@ void ggml_gemv_q8_0_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v } template -static void inline ggml_gemv_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +static inline void ggml_gemv_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { assert(n % QK_K == 0); assert(nr == 1); assert(nc % ncols_interleaved == 0); UNUSED(bs); + UNUSED(nr); const int num_k_blocks = n / QK_K; @@ -507,7 +508,6 @@ static void inline ggml_gemv_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_ } } - void ggml_gemv_q2_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { ggml_gemv_q2_K_Mx1_q8_K<8>(n, s, 
bs, vx, vy, nr, nc); } @@ -522,7 +522,7 @@ void ggml_gemv_q2_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v } template -void ggml_gemv_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +static inline void ggml_gemv_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; const int blocklen = 1; @@ -551,6 +551,10 @@ void ggml_gemv_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo for (int l = 0; l < nb; l++) { vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + // Load `dmins`. + const vfloat32m2_t dmins_d = __riscv_vfmul_vf_f32m2( + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d, ncols_interleaved); + // We process 4 sub-blocks at once. const int vl = ncols_interleaved * 4; for (int j = 0; j < QK_K / 128; j++) { @@ -584,8 +588,6 @@ void ggml_gemv_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 4] + a_ptr[l].bsums[j * 8 + 5], __riscv_vget_v_i16m4_i16m1(mins, 2), ncols_interleaved); bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 6] + a_ptr[l].bsums[j * 8 + 7], __riscv_vget_v_i16m4_i16m1(mins, 3), ncols_interleaved); - const vfloat32m2_t dmins_d = __riscv_vfmul_vf_f32m2( - __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d, ncols_interleaved); sumf = __riscv_vfsub_vv_f32m2(sumf, __riscv_vfmul_vv_f32m2(dmins_d, __riscv_vfcvt_f_x_v_f32m2(bsums, ncols_interleaved), ncols_interleaved), ncols_interleaved); // Accumulation for 2 sub-blocks. 
@@ -668,7 +670,7 @@ void ggml_gemv_q4_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v } template -void ggml_gemv_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +static inline void ggml_gemv_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; const int blocklen = 1; @@ -697,6 +699,10 @@ void ggml_gemv_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo for (int l = 0; l < nb; l++) { vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + // Load `dmins`. + const vfloat32m2_t dmins_d = __riscv_vfmul_vf_f32m2( + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d, ncols_interleaved); + // We process 4 sub-blocks at once. const int vl = ncols_interleaved * 4; for (int j = 0; j < QK_K / 128; j++) { @@ -730,8 +736,6 @@ void ggml_gemv_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 4] + a_ptr[l].bsums[j * 8 + 5], __riscv_vget_v_i16m4_i16m1(mins, 2), ncols_interleaved); bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 6] + a_ptr[l].bsums[j * 8 + 7], __riscv_vget_v_i16m4_i16m1(mins, 3), ncols_interleaved); - const vfloat32m2_t dmins_d = __riscv_vfmul_vf_f32m2( - __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d, ncols_interleaved); sumf = __riscv_vfsub_vv_f32m2(sumf, __riscv_vfmul_vv_f32m2(dmins_d, __riscv_vfcvt_f_x_v_f32m2(bsums, ncols_interleaved), ncols_interleaved), ncols_interleaved); // Accumulation for 2 sub-blocks. 
@@ -892,6 +896,78 @@ void ggml_gemv_iq4_nl_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void ggml_gemv_iq4_nl_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { ggml_gemv_iq4_nl_Mx1_q8_0<64>(n, s, bs, vx, vy, nr, nc); } + +template +static inline void ggml_gemv_mxfp4_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + const vint8m1_t values = __riscv_vle8_v_i8m1(kvalues_mxfp4, 16); + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_mxfp4x * b_ptr = (const block_mxfp4x *) vx + (x * nb); + + // 1xM Accumulator + vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + + for (int l = 0; l < nb; l++) { + // 1xM integer accumulator + vint16m1_t sumi_0_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_0_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + + // Accumulation loop. + for (int i = 0; i < QK_MXFP4 / 2; i++) { + // Load `b_ptr`. 
+ const vuint8m1_t b_0_packed = __riscv_vle8_v_u8m1((const uint8_t *)&b_ptr[l].qs[i * ncols_interleaved], ncols_interleaved); + const vint8m1_t b_0_lo = __riscv_vrgather_vv_i8m1(values, __riscv_vand_vx_u8m1(b_0_packed, 0xf, ncols_interleaved), ncols_interleaved); + const vint8m1_t b_0_hi = __riscv_vrgather_vv_i8m1(values, __riscv_vsrl_vx_u8m1(b_0_packed, 4, ncols_interleaved), ncols_interleaved); + + sumi_0_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_0_lo_16, a_ptr[l].qs[i], __riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_lo), ncols_interleaved); + sumi_0_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_0_hi_16, a_ptr[l].qs[16 + i], __riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_hi), ncols_interleaved); + } + + const vint32m2_t sumi = __riscv_vwadd_vv_i32m2(sumi_0_lo_16, sumi_0_hi_16, ncols_interleaved); + + float b_scales[ncols_interleaved]; + for (int i = 0; i < ncols_interleaved; i++) { + b_scales[i] = GGML_E8M0_TO_FP32_HALF(b_ptr[l].e[i]); + } + const vfloat32m2_t b_e = __riscv_vle32_v_f32m2((const float *)&b_scales[0], ncols_interleaved); + const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_e, GGML_FP16_TO_FP32(a_ptr[l].d), ncols_interleaved); + + sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, ncols_interleaved), d_0, ncols_interleaved); + } + + __riscv_vse32_v_f32m2(s + x * ncols_interleaved, sumf, ncols_interleaved); + } +} + +void ggml_gemv_mxfp4_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_mxfp4_Mx1_q8_0<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_mxfp4_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_mxfp4_Mx1_q8_0<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_mxfp4_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_mxfp4_Mx1_q8_0<32>(n, s, bs, vx, vy, nr, nc); +} +void 
ggml_gemv_mxfp4_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_mxfp4_Mx1_q8_0<64>(n, s, bs, vx, vy, nr, nc); +} #endif void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -1122,7 +1198,7 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo #if defined __riscv_zvfh template -void ggml_gemm_q4_0_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +static inline void ggml_gemm_q4_0_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; const int blocklen = 1; @@ -1221,7 +1297,7 @@ void ggml_gemm_q4_0_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v } template -void ggml_gemm_q8_0_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +static inline void ggml_gemm_q8_0_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; const int blocklen = 1; @@ -1303,7 +1379,7 @@ void ggml_gemm_q8_0_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v } template -static void ggml_gemm_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +static inline void ggml_gemm_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { assert(n % QK_K == 0); const int num_k_blocks = n / QK_K; const int N_ROWS_TILE = 4; @@ -1652,6 +1728,9 @@ void ggml_gemm_q4_K_Mx1_q8_K(int n, float * 
GGML_RESTRICT s, size_t bs, const vo vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + // Load `dmins`. + const vfloat32m2_t dmins = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved); + // We process 4 sub-blocks at once. const int vl = ncols_interleaved * 4; for (int j = 0; j < QK_K / 128; j++) { @@ -1733,14 +1812,10 @@ void ggml_gemm_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo a_ptr[l].bsums[j * 32 + 24 + 3] + a_ptr[l].bsums[j * 32 + 24 + 7], __riscv_vget_v_i16m4_i16m1(mins, 3), ncols_interleaved); - const vfloat32m2_t dmins_d_0 = __riscv_vfmul_vf_f32m2( - __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[0], ncols_interleaved); - const vfloat32m2_t dmins_d_1 = __riscv_vfmul_vf_f32m2( - __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[1], ncols_interleaved); - const vfloat32m2_t dmins_d_2 = __riscv_vfmul_vf_f32m2( - __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[2], ncols_interleaved); - const vfloat32m2_t dmins_d_3 = __riscv_vfmul_vf_f32m2( - __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[3], ncols_interleaved); + const vfloat32m2_t dmins_d_0 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[0], ncols_interleaved); + const vfloat32m2_t dmins_d_1 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[1], ncols_interleaved); + const vfloat32m2_t dmins_d_2 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[2], ncols_interleaved); + const vfloat32m2_t dmins_d_3 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[3], ncols_interleaved); sumf_0 = __riscv_vfsub_vv_f32m2(sumf_0, 
__riscv_vfmul_vv_f32m2(dmins_d_0, __riscv_vfcvt_f_x_v_f32m2(bsums_0, ncols_interleaved), ncols_interleaved), ncols_interleaved); sumf_1 = __riscv_vfsub_vv_f32m2(sumf_1, __riscv_vfmul_vv_f32m2(dmins_d_1, __riscv_vfcvt_f_x_v_f32m2(bsums_1, ncols_interleaved), ncols_interleaved), ncols_interleaved); @@ -1900,7 +1975,7 @@ void ggml_gemm_q4_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v } template -void ggml_gemm_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +static inline void ggml_gemm_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; const int blocklen = 1; @@ -1936,6 +2011,9 @@ void ggml_gemm_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + // Load `dmins`. + const vfloat32m2_t dmins = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved); + // We process 4 sub-blocks at once. 
const int vl = ncols_interleaved * 4; for (int j = 0; j < QK_K / 128; j++) { @@ -2017,14 +2095,10 @@ void ggml_gemm_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo a_ptr[l].bsums[j * 32 + 24 + 3] + a_ptr[l].bsums[j * 32 + 24 + 7], __riscv_vget_v_i16m4_i16m1(mins, 3), ncols_interleaved); - const vfloat32m2_t dmins_d_0 = __riscv_vfmul_vf_f32m2( - __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[0], ncols_interleaved); - const vfloat32m2_t dmins_d_1 = __riscv_vfmul_vf_f32m2( - __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[1], ncols_interleaved); - const vfloat32m2_t dmins_d_2 = __riscv_vfmul_vf_f32m2( - __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[2], ncols_interleaved); - const vfloat32m2_t dmins_d_3 = __riscv_vfmul_vf_f32m2( - __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[3], ncols_interleaved); + const vfloat32m2_t dmins_d_0 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[0], ncols_interleaved); + const vfloat32m2_t dmins_d_1 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[1], ncols_interleaved); + const vfloat32m2_t dmins_d_2 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[2], ncols_interleaved); + const vfloat32m2_t dmins_d_3 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[3], ncols_interleaved); sumf_0 = __riscv_vfsub_vv_f32m2(sumf_0, __riscv_vfmul_vv_f32m2(dmins_d_0, __riscv_vfcvt_f_x_v_f32m2(bsums_0, ncols_interleaved), ncols_interleaved), ncols_interleaved); sumf_1 = __riscv_vfsub_vv_f32m2(sumf_1, __riscv_vfmul_vv_f32m2(dmins_d_1, __riscv_vfcvt_f_x_v_f32m2(bsums_1, ncols_interleaved), ncols_interleaved), ncols_interleaved); @@ -2185,32 +2259,16 @@ void ggml_gemm_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, 
const vo } void ggml_gemm_q5_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { -#if defined __riscv_zvfh ggml_gemm_q5_K_Mx1_q8_K<8>(n, s, bs, vx, vy, nr, nc); -#else - ggml_gemm_q5_K_8x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); -#endif } void ggml_gemm_q5_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { -#if defined __riscv_zvfh ggml_gemm_q5_K_Mx1_q8_K<16>(n, s, bs, vx, vy, nr, nc); -#else - ggml_gemm_q5_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); -#endif } void ggml_gemm_q5_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { -#if defined __riscv_zvfh ggml_gemm_q5_K_Mx1_q8_K<32>(n, s, bs, vx, vy, nr, nc); -#else - ggml_gemm_q5_K_32x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); -#endif } void ggml_gemm_q5_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { -#if defined __riscv_zvfh ggml_gemm_q5_K_Mx1_q8_K<64>(n, s, bs, vx, vy, nr, nc); -#else - ggml_gemm_q5_K_64x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); -#endif } template @@ -2233,7 +2291,6 @@ void ggml_gemm_iq4_nl_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const UNUSED(ncols_interleaved); UNUSED(blocklen); -#if defined __riscv_v_intrinsic const vint8m1_t values = __riscv_vle8_v_i8m1(kvalues_iq4nl, 16); for (int y = 0; y < nr / 4; y++) { @@ -2260,8 +2317,6 @@ void ggml_gemm_iq4_nl_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const const vuint8m1_t b_0_packed = __riscv_vle8_v_u8m1((const uint8_t *)&b_ptr[l].qs[i * ncols_interleaved], ncols_interleaved); const vint8m1_t b_0_lo = __riscv_vrgather_vv_i8m1(values, __riscv_vand_vx_u8m1(b_0_packed, 0xf, ncols_interleaved), ncols_interleaved); const vint8m1_t b_0_hi = __riscv_vrgather_vv_i8m1(values, 
__riscv_vsrl_vx_u8m1(b_0_packed, 4, ncols_interleaved), ncols_interleaved); - // const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16); - // const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16); const vint16m1_t sumi_0_lo = __riscv_vwmul_vx_i16m1(__riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_lo), a_ptr[l].qs[i * 4], ncols_interleaved); const vint16m1_t sumi_1_lo = __riscv_vwmul_vx_i16m1(__riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_lo), a_ptr[l].qs[i * 4 + 1], ncols_interleaved); @@ -2297,9 +2352,6 @@ void ggml_gemm_iq4_nl_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * ncols_interleaved, sumf_3, ncols_interleaved); } } - return; -#endif - ggml_gemm_iq4_nl_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); } void ggml_gemm_iq4_nl_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -2314,4 +2366,110 @@ void ggml_gemm_iq4_nl_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void ggml_gemm_iq4_nl_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { ggml_gemm_iq4_nl_Mx1_q8_0<64>(n, s, bs, vx, vy, nr, nc); } + +template +static inline void ggml_gemm_mxfp4_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + const vint8m1_t values = __riscv_vle8_v_i8m1(kvalues_mxfp4, 16); + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_mxfp4x * 
b_ptr = (const block_mxfp4x *) vx + (x * nb); + + // 4xM Accumulators + vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + + for (int l = 0; l < nb; l++) { + // 4xM integer accumulators + vint16m1_t sumi_0_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_1_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_2_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_3_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_0_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_1_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_2_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_3_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + + // Accumulation loop. + for (int i = 0; i < QK_MXFP4 / 2; i++) { + // Load `b_ptr`. 
+ const vuint8m1_t b_0_packed = __riscv_vle8_v_u8m1((const uint8_t *)&b_ptr[l].qs[i * ncols_interleaved], ncols_interleaved); + const vint8m1_t b_0_lo = __riscv_vrgather_vv_i8m1(values, __riscv_vand_vx_u8m1(b_0_packed, 0xf, ncols_interleaved), ncols_interleaved); + const vint8m1_t b_0_hi = __riscv_vrgather_vv_i8m1(values, __riscv_vsrl_vx_u8m1(b_0_packed, 4, ncols_interleaved), ncols_interleaved); + + sumi_0_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_0_lo_16, a_ptr[l].qs[i * 4], __riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_lo), ncols_interleaved); + sumi_1_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_1_lo_16, a_ptr[l].qs[i * 4 + 1], __riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_lo), ncols_interleaved); + sumi_2_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_2_lo_16, a_ptr[l].qs[i * 4 + 2], __riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_lo), ncols_interleaved); + sumi_3_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_3_lo_16, a_ptr[l].qs[i * 4 + 3], __riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_lo), ncols_interleaved); + + sumi_0_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_0_hi_16, a_ptr[l].qs[64 + i * 4], __riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_hi), ncols_interleaved); + sumi_1_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_1_hi_16, a_ptr[l].qs[64 + i * 4 + 1], __riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_hi), ncols_interleaved); + sumi_2_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_2_hi_16, a_ptr[l].qs[64 + i * 4 + 2], __riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_hi), ncols_interleaved); + sumi_3_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_3_hi_16, a_ptr[l].qs[64 + i * 4 + 3], __riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_hi), ncols_interleaved); + } + + // Do the final accumulation in i32 to prevent overflow. 
+ const vint32m2_t sumi_0 = __riscv_vwadd_vv_i32m2(sumi_0_lo_16, sumi_0_hi_16, ncols_interleaved); + const vint32m2_t sumi_1 = __riscv_vwadd_vv_i32m2(sumi_1_lo_16, sumi_1_hi_16, ncols_interleaved); + const vint32m2_t sumi_2 = __riscv_vwadd_vv_i32m2(sumi_2_lo_16, sumi_2_hi_16, ncols_interleaved); + const vint32m2_t sumi_3 = __riscv_vwadd_vv_i32m2(sumi_3_lo_16, sumi_3_hi_16, ncols_interleaved); + + float b_scales[ncols_interleaved]; + for (int i = 0; i < ncols_interleaved; i++) { + b_scales[i] = GGML_E8M0_TO_FP32_HALF(b_ptr[l].e[i]); + } + const vfloat32m2_t b_e = __riscv_vle32_v_f32m2((const float *)&b_scales[0], ncols_interleaved); + + const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_e, GGML_FP16_TO_FP32(a_ptr[l].d[0]), ncols_interleaved); + const vfloat32m2_t d_1 = __riscv_vfmul_vf_f32m2(b_e, GGML_FP16_TO_FP32(a_ptr[l].d[1]), ncols_interleaved); + const vfloat32m2_t d_2 = __riscv_vfmul_vf_f32m2(b_e, GGML_FP16_TO_FP32(a_ptr[l].d[2]), ncols_interleaved); + const vfloat32m2_t d_3 = __riscv_vfmul_vf_f32m2(b_e, GGML_FP16_TO_FP32(a_ptr[l].d[3]), ncols_interleaved); + + sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, ncols_interleaved), d_0, ncols_interleaved); + sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, ncols_interleaved), d_1, ncols_interleaved); + sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, ncols_interleaved), d_2, ncols_interleaved); + sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, ncols_interleaved), d_3, ncols_interleaved); + } + + __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * ncols_interleaved, sumf_0, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * ncols_interleaved, sumf_1, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * ncols_interleaved, sumf_2, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * ncols_interleaved, sumf_3, ncols_interleaved); + } + } +} + +void 
ggml_gemm_mxfp4_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_mxfp4_Mx1_q8_0<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_mxfp4_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_mxfp4_Mx1_q8_0<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_mxfp4_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_mxfp4_Mx1_q8_0<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_mxfp4_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_mxfp4_Mx1_q8_0<64>(n, s, bs, vx, vy, nr, nc); +} #endif diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index ab0c9bda07..5daa0a9d97 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -1,4 +1,3 @@ -#include "ggml.h" #define GGML_COMMON_IMPL_CPP #define GGML_COMMON_DECL_CPP #include "ggml-common.h" @@ -725,6 +724,44 @@ static inline void ggml_gemv_iq4_nl_Mx1_q8_0_generic(int n, float * GGML_RESTRIC for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; } } + +template +static inline void ggml_gemv_mxfp4_Mx1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int blocklen = 1; + + assert(nr == 1); + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[ncols_interleaved]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_mxfp4x * b_ptr = (const block_mxfp4x *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) 
sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])); + } + sumf[j] += sumi * GGML_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +} #endif #if defined __riscv_zvfh @@ -1157,8 +1194,448 @@ static inline void ggml_gemm_iq4_nl_Mx1_q8_0_generic(int n, float * GGML_RESTRIC } } } + +template +static inline void ggml_gemm_mxfp4_Mx1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int blocklen = 1; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + float sumf[4][ncols_interleaved]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_mxfp4x * b_ptr = (const block_mxfp4x *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] 
>> 4]; + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + (qk / 2) * 4])); + } + sumf[m][j] += sumi * GGML_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +} #endif +template +static void ggml_gemv_q6_K_NxM_q8_K_generic_impl(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int blocklen = M; + constexpr int ncols_interleaved = N; + const int qk = QK_K; + const int nb = n / qk; + const int blocks_per_half = 64 / blocklen; + + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[8]; + + const block_q8_K * a_ptr = (const block_q8_K *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0f; + } + + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + const int base_l = (k / blocks_per_half) * 128 + (k % blocks_per_half) * blocklen; + const int base_h = base_l + 64; + + const int scale_idx_l = base_l / 16; + const int scale_idx_h = base_h / 16; + + const int qh_shift_l = ((base_l % 128) / 32) * 2; + const int qh_shift_h = ((base_h % 128) / 32) * 2; + + const int qh_half_l = (base_l / 128) * 32; + const int qh_half_h = (base_h / 128) * 32; + + for (int j = 0; j < ncols_interleaved; j++) { + const int8_t scale_l = b_ptr[l].scales[scale_idx_l * ncols_interleaved + j]; + const int8_t scale_h = b_ptr[l].scales[scale_idx_h * ncols_interleaved + j]; + + int sumi_l = 0; + int sumi_h = 0; + + for (int i = 0; i < blocklen; i++) { + const int ql_pos = k * ncols_interleaved * blocklen + j * blocklen + 
i; + const int l_4 = b_ptr[l].ql[ql_pos] & 0xF; + const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF; + + const int qh_idx_l = qh_half_l + ((base_l + i) % 32); + const int qh_chunk_l = qh_idx_l / blocklen; + const int qh_pos_l = qh_idx_l % blocklen; + const int qh_offset_l = qh_chunk_l * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_l; + const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3; + + const int qh_idx_h = qh_half_h + ((base_h + i) % 32); + const int qh_chunk_h = qh_idx_h / blocklen; + const int qh_pos_h = qh_idx_h % blocklen; + const int qh_offset_h = qh_chunk_h * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_h; + const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3; + + const int q_l = ((hi_2_l << 4) | l_4) - 32; + const int q_h = ((hi_2_h << 4) | hi_4) - 32; + + const int8_t a_l = a_ptr[l].qs[base_l + i]; + const int8_t a_h = a_ptr[l].qs[base_h + i]; + + sumi_l += q_l * a_l; + sumi_h += q_h * a_h; + } + + sumf[j] += + (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; + } + } + } + + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j]; + } + } +} + +template +static void ggml_gemm_q6_K_NxM_q8_K_generic_impl(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int blocklen = M; + constexpr int ncols_interleaved = N; + const int qk = QK_K; + const int nb = n / qk; + const int blocks_per_half = 64 / blocklen; + const int q8_half_stride = 512; + const int q8_low_high_step = 256; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + + float sumf[4][8]; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb); + + for (int m = 0; m 
< 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0f; + } + } + + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + const int base_l = (k / blocks_per_half) * 128 + (k % blocks_per_half) * blocklen; + const int base_h = base_l + 64; + + const int scale_idx_l = base_l / 16; + const int scale_idx_h = base_h / 16; + + const int qh_shift_l = ((base_l % 128) / 32) * 2; + const int qh_shift_h = ((base_h % 128) / 32) * 2; + + const int qh_half_l = (base_l / 128) * 32; + const int qh_half_h = (base_h / 128) * 32; + + const int q8_base = (k / blocks_per_half) * q8_half_stride + (k % blocks_per_half) * (blocklen * 4); + + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + const int8_t scale_l = b_ptr[l].scales[scale_idx_l * ncols_interleaved + j]; + const int8_t scale_h = b_ptr[l].scales[scale_idx_h * ncols_interleaved + j]; + + int sumi_l = 0; + int sumi_h = 0; + + for (int i = 0; i < blocklen; i++) { + const int ql_pos = k * ncols_interleaved * blocklen + j * blocklen + i; + const int l_4 = b_ptr[l].ql[ql_pos] & 0xF; + const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF; + + const int qh_idx_l = qh_half_l + ((base_l + i) % 32); + const int qh_chunk_l = qh_idx_l / blocklen; + const int qh_pos_l = qh_idx_l % blocklen; + const int qh_offset_l = + qh_chunk_l * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_l; + const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3; + + const int qh_idx_h = qh_half_h + ((base_h + i) % 32); + const int qh_chunk_h = qh_idx_h / blocklen; + const int qh_pos_h = qh_idx_h % blocklen; + const int qh_offset_h = + qh_chunk_h * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_h; + const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3; + + const int q_l = ((hi_2_l << 4) | l_4) - 32; + const int q_h = ((hi_2_h << 4) | hi_4) - 32; + + const int8_t q8_l = a_ptr[l].qs[q8_base + m * blocklen + i]; + const int8_t q8_h = 
a_ptr[l].qs[q8_base + m * blocklen + i + q8_low_high_step]; + + sumi_l += q_l * q8_l; + sumi_h += q_h * q8_h; + } + + sumf[m][j] += (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * + a_ptr[l].d[m]; + } + } + } + } + + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } + } +} + +template +static void ggml_gemv_q5_K_NxM_q8_K_generic_impl(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int blocklen = M; + constexpr int ncols_interleaved = N; + const int qk = QK_K; + const int nb = n / qk; + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[ncols_interleaved]; + float sum_minf[ncols_interleaved]; + uint32_t utmp[32]; + int sumi1; + int sumi2; + int sumi; + + const block_q8_K * a_ptr = (const block_q8_K *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0; + sum_minf[j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int sb = 0; sb < 8; sb++) { + memcpy(utmp + sb * 4, b_ptr[l].scales + sb * K_SCALE_SIZE, K_SCALE_SIZE); + utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); + const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; + utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); + utmp[sb * 4 + 2] = uaux_0; + utmp[sb * 4 + 0] &= kmask1; + } + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + constexpr int scale_stride = 32; + uint8_t * scales_0 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride; + uint8_t * scales_1 = 
(uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride + 16; + + const int qh_shift = (k / (32 / blocklen)) * 2; + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i; + + const int qh_idx = (k * blocklen + i) % 32; + const int qh_chunk = qh_idx / blocklen; + const int qh_pos = qh_idx % blocklen; + const int b_qh_offset = qh_chunk * (blocklen * ncols_interleaved) + j * blocklen + qh_pos; + + const uint8_t qh_val = b_ptr[l].qh[b_qh_offset]; + const uint8_t h0 = (qh_val >> qh_shift) & 1; + const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1; + + const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4)); + const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4)); + + const int q8_offset = (k / (32 / blocklen)) * 64 + (k % (32 / blocklen)) * blocklen + i; + + sumi1 = (v0 * a_ptr[l].qs[q8_offset]); + sumi2 = (v1 * a_ptr[l].qs[q8_offset + 32]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; + } + } + for (int sb = 0; sb < 8; sb++) { + uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16; + for (int j = 0; j < ncols_interleaved; j++) { + sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * + GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d; + } + } + } + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j]; + } + } +} + +template +static void ggml_gemm_q5_K_NxM_q8_K_generic_impl(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int blocklen = M; + constexpr int ncols_interleaved = N; + const int qk = QK_K; + const int nb = n / qk; + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; 
+ static const uint32_t kmask3 = 0x03030303; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + float sumf[4][ncols_interleaved]; + float sum_minf[4][ncols_interleaved]; + uint32_t utmp[32]; + int sumi1; + int sumi2; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0; + sum_minf[m][j] = 0.0; + } + } + for (int l = 0; l < nb; l++) { + for (int sb = 0; sb < 8; sb++) { + memcpy(utmp + sb * 4, b_ptr[l].scales + sb * K_SCALE_SIZE, K_SCALE_SIZE); + utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); + const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; + utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); + utmp[sb * 4 + 2] = uaux_0; + utmp[sb * 4 + 0] &= kmask1; + } + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + constexpr int scale_stride = 32; + uint8_t * scales_0 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride; + uint8_t * scales_1 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride + 16; + + const int qh_shift = (k / (32 / blocklen)) * 2; + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i; + + const int qh_idx = (k * blocklen + i) % 32; + const int qh_chunk = qh_idx / blocklen; + const int qh_pos = qh_idx % blocklen; + const int b_qh_offset = + qh_chunk * (blocklen * ncols_interleaved) + j * blocklen + qh_pos; + + const uint8_t qh_val = b_ptr[l].qh[b_qh_offset]; + const uint8_t h0 = (qh_val >> qh_shift) & 1; + const uint8_t h1 = (qh_val >> (qh_shift + 1)) 
& 1; + + const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4)); + const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4)); + + const int q8_offset = (k / (32 / blocklen)) * 256 + + (k % (32 / blocklen)) * 4 * blocklen + m * blocklen + i; + + sumi1 = (v0 * a_ptr[l].qs[q8_offset]); + sumi2 = (v1 * a_ptr[l].qs[q8_offset + 128]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + } + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m]; + } + } + } + for (int sb = 0; sb < 8; sb++) { + uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16; + for (int m = 0; m < 4; m++) { + const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6); + for (int j = 0; j < ncols_interleaved; j++) { + sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * + GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m]; + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j]; + } + } + } + } +} + extern "C" { void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -1520,8 +1997,100 @@ void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, ggml_gemv_q5_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc); } +void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q6_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - ggml_gemv_q6_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc); + constexpr int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 
8; + const int blocklen = 8; + + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[8]; + + const block_q8_K * a_ptr = (const block_q8_K *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0f; + } + + for (int l = 0; l < nb; l++) { + + + for (int k = 0; k < 16; k++) { + // k = 0.. 7 weights 0-63 low, 64-127 high + // k = 8..15 weights 128-191 low, 192-255 high + const int base_l = (k / 8) * 128 + (k % 8) * 8; + const int base_h = base_l + 64; + + const int scale_idx_l = base_l / 16; + const int scale_idx_h = base_h / 16; + + // Bit shift cycles 0,2,4,6 for each 32-value group within a 128-value half + const int qh_shift_l = ((base_l % 128) / 32) * 2; + const int qh_shift_h = ((base_h % 128) / 32) * 2; + + // qh_half: offset to the correct 32-byte half (0 or 32) + const int qh_half_l = (base_l / 128) * 32; + const int qh_half_h = (base_h / 128) * 32; + + for (int j = 0; j < ncols_interleaved; j++) { + // Interleaved scales + const int8_t scale_l = b_ptr[l].scales[scale_idx_l * 8 + j]; + const int8_t scale_h = b_ptr[l].scales[scale_idx_h * 8 + j]; + + int sumi_l = 0; + int sumi_h = 0; + + for (int i = 0; i < blocklen; i++) { + const int ql_pos = k * 64 + j * 8 + i; + const int l_4 = b_ptr[l].ql[ql_pos] & 0xF; + const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF; + + // qh indexing with 8-byte interleaving (like q5_K) + const int qh_byte_l = qh_half_l + ((base_l + i) % 32); + const int qh_chunk_l = qh_byte_l / 8; + const int qh_pos_l = qh_byte_l % 8; + const int qh_offset_l = qh_chunk_l * 64 + j * 8 + qh_pos_l; + const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3; + + const int qh_byte_h = qh_half_h + ((base_h + i) % 32); + const int qh_chunk_h = qh_byte_h / 8; + const int qh_pos_h = qh_byte_h % 8; + const int qh_offset_h = qh_chunk_h * 64 + j * 8 + qh_pos_h; + 
const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3; + + const int q_l = ((hi_2_l << 4) | l_4) - 32; + const int q_h = ((hi_2_h << 4) | hi_4) - 32; + + const int8_t a_l = a_ptr[l].qs[base_l + i]; + const int8_t a_h = a_ptr[l].qs[base_h + i]; + + sumi_l += q_l * a_l; + sumi_h += q_h * a_h; + + printf("w: %d %d, b: %d %d %d\n", q_l, a_l, l_4, hi_2_l, b_ptr[l].qh[qh_offset_h]); + } + + sumf[j] += + (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; + } + } + } + + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j]; + } + } } void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -1855,6 +2424,20 @@ void ggml_gemv_iq4_nl_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t b void ggml_gemv_iq4_nl_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { ggml_gemv_iq4_nl_Mx1_q8_0_generic<64>(n, s, bs, vx, vy, nr, nc); } + +// MXFP4 +void ggml_gemv_mxfp4_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_mxfp4_Mx1_q8_0_generic<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_mxfp4_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_mxfp4_Mx1_q8_0_generic<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_mxfp4_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_mxfp4_Mx1_q8_0_generic<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_mxfp4_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + 
ggml_gemv_mxfp4_Mx1_q8_0_generic<64>(n, s, bs, vx, vy, nr, nc); +} #endif void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -2670,6 +3253,20 @@ void ggml_gemm_iq4_nl_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t b void ggml_gemm_iq4_nl_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { ggml_gemm_iq4_nl_Mx1_q8_0_generic<64>(n, s, bs, vx, vy, nr, nc); } + +// MXFP4 +void ggml_gemm_mxfp4_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_mxfp4_Mx1_q8_0_generic<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_mxfp4_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_mxfp4_Mx1_q8_0_generic<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_mxfp4_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_mxfp4_Mx1_q8_0_generic<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_mxfp4_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_mxfp4_Mx1_q8_0_generic<64>(n, s, bs, vx, vy, nr, nc); +} #endif } // extern "C" @@ -3769,6 +4366,59 @@ static int repack_iq4_nl_to_iq4_nl_Mx1_bl(struct ggml_tensor * t, const void * G GGML_UNUSED(data_size); } + +template +static block_mxfp4x make_block_mxfp4xMx1(block_mxfp4 * in) { + block_mxfp4x out; + + for (int i = 0; i < nrows_interleaved; i++) { + out.e[i] = in[i].e; + } + + const int end = QK_MXFP4 * nrows_interleaved / 2; + + for (int i = 0; i < end; ++i) { + int src_id = i % nrows_interleaved; + int src_offset = i / nrows_interleaved; 
+ int dst_offset = i; + + out.qs[dst_offset] = in[src_id].qs[src_offset]; + } + + return out; +} + +template +static int repack_mxfp4_to_mxfp4_Mx1_bl(struct ggml_tensor * t, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_MXFP4); + + const block_mxfp4 * src = (const block_mxfp4 *)data; + block_mxfp4x * dst = ( block_mxfp4x *)t->data; + + block_mxfp4 dst_tmp[nrows_interleaved]; + + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_MXFP4; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_mxfp4)); + + if (t->ne[1] % nrows_interleaved != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_mxfp4xMx1(dst_tmp); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} #endif static block_mxfp4x4 make_block_mxfp4x4(block_mxfp4 * in, unsigned int blck_size_interleave) { @@ -4044,6 +4694,20 @@ template <> int repack(struct ggml_tensor * t, const void * template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_iq4_nl_to_iq4_nl_Mx1_bl<64>(t, data, data_size); } + +// MXFP4 +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_mxfp4_to_mxfp4_Mx1_bl<8>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_mxfp4_to_mxfp4_Mx1_bl<16>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_mxfp4_to_mxfp4_Mx1_bl<32>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_mxfp4_to_mxfp4_Mx1_bl<64>(t, data, data_size); +} #endif // gemv @@ -4205,6 +4869,20 @@ template <> void gemv(int n, float * s, siz template <> void 
gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_iq4_nl_64x1_q8_0(n, s, bs, vx, vy, nr, nc); } + +// MXFP4 +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_mxfp4_8x1_q8_0(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_mxfp4_16x1_q8_0(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_mxfp4_32x1_q8_0(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_mxfp4_64x1_q8_0(n, s, bs, vx, vy, nr, nc); +} #endif // gemm @@ -4366,6 +5044,20 @@ template <> void gemm(int n, float * s, siz template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_iq4_nl_64x1_q8_0(n, s, bs, vx, vy, nr, nc); } + +// MXFP4 +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_mxfp4_8x1_q8_0(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_mxfp4_16x1_q8_0(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_mxfp4_32x1_q8_0(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_mxfp4_64x1_q8_0(n, s, bs, vx, vy, nr, nc); +} #endif class tensor_traits_base : public ggml::cpu::tensor_traits { @@ -4816,6 +5508,12 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons static const ggml::cpu::repack::tensor_traits iq4_nl_16x1_q8_0; static const ggml::cpu::repack::tensor_traits 
// 8 q6_K blocks from 8 consecutive rows, fields interleaved column-major.
struct block_q6_Kx8 {
    ggml_half d[8];                  // per-row super-block deltas
    int8_t    scales[QK_K / 16 * 8]; // per-row sub-block scales (one per 16 values), interleaved
    uint8_t   ql[QK_K / 2 * 8];      // low bits of 6-bit quants (groups of 2)
    uint8_t   qh[QK_K / 4 * 8];      // high bits of 6-bit quants (groups of 4)
};

static_assert(sizeof(block_q6_Kx8) == sizeof(ggml_half) * 8 + QK_K / 16 * 8 + 3 * QK_K / 4 * 8,
              "wrong q6_K block size/padding");

// N mxfp4 blocks from N consecutive rows, repacked with blocklen 1:
// all N scales first, then the quant bytes interleaved round-robin.
// NOTE(review): block_mxfp4 stores its scale as a raw uint8_t E8M0 exponent;
// here `e` is ggml_half — confirm the repack/kernels agree on this widening.
template <int N> struct block_mxfp4x {
    ggml_half e[N];               // deltas for `N` mxfp4 blocks
    uint8_t   qs[QK_MXFP4 * N / 2]; // nibbles / quants for N mxfp4 blocks
};

using block_mxfp4x4  = block_mxfp4x<4>;
using block_mxfp4x8  = block_mxfp4x<8>;
using block_mxfp4x16 = block_mxfp4x<16>;
using block_mxfp4x32 = block_mxfp4x<32>;
using block_mxfp4x64 = block_mxfp4x<64>;

static_assert(sizeof(block_mxfp4x4)  == 4  * sizeof(ggml_half) + QK_MXFP4 * 2,  "wrong mxfp4x4 block size/padding");
static_assert(sizeof(block_mxfp4x8)  == 8  * sizeof(ggml_half) + QK_MXFP4 * 4,  "wrong mxfp4x8 block size/padding");
static_assert(sizeof(block_mxfp4x16) == 16 * sizeof(ggml_half) + QK_MXFP4 * 8,  "wrong mxfp4x16 block size/padding");
static_assert(sizeof(block_mxfp4x32) == 32 * sizeof(ggml_half) + QK_MXFP4 * 16, "wrong mxfp4x32 block size/padding");
static_assert(sizeof(block_mxfp4x64) == 64 * sizeof(ggml_half) + QK_MXFP4 * 32, "wrong mxfp4x64 block size/padding");
ggml_gemm_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); #if defined __riscv_zvfh void ggml_quantize_mat_q8_0_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_quantize_mat_q8_K_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_gemv_q4_0_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -203,6 +220,10 @@ void ggml_gemv_iq4_nl_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_mxfp4_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_mxfp4_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_mxfp4_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_mxfp4_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int 
nc); void ggml_gemm_q4_0_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -227,6 +248,10 @@ void ggml_gemm_iq4_nl_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_mxfp4_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_mxfp4_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_mxfp4_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_mxfp4_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); #endif // Native implementations @@ -293,6 +318,10 @@ void ggml_gemv_iq4_nl_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_32x1_q8_0_generic(int n, float * 
GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_mxfp4_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_mxfp4_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_mxfp4_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_mxfp4_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -317,6 +346,10 @@ void ggml_gemm_iq4_nl_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void 
ggml_gemm_mxfp4_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_mxfp4_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_mxfp4_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_mxfp4_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); #endif #if defined(__cplusplus)