ggml-cpu: refactor; add rvv repacking for mxfp4

This commit is contained in:
taimur-10x 2026-03-04 19:00:30 +05:00
parent 977beacc4e
commit 04719ae517
3 changed files with 964 additions and 62 deletions

View File

@ -271,7 +271,7 @@ void ggml_gemv_q4_0_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v
}
template<int ncols_interleaved>
void ggml_gemv_q8_0_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
static inline void ggml_gemv_q8_0_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK8_0;
const int nb = n / qk;
const int blocklen = 1;
@ -334,12 +334,13 @@ void ggml_gemv_q8_0_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v
}
template<int ncols_interleaved>
static void inline ggml_gemv_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
static inline void ggml_gemv_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
assert(n % QK_K == 0);
assert(nr == 1);
assert(nc % ncols_interleaved == 0);
UNUSED(bs);
UNUSED(nr);
const int num_k_blocks = n / QK_K;
@ -507,7 +508,6 @@ static void inline ggml_gemv_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_
}
}
// Dispatch wrapper: q2_K GEMV with 8 interleaved columns (delegates to the
// shared templated kernel).
void ggml_gemv_q2_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    ggml_gemv_q2_K_Mx1_q8_K<8>(n, s, bs, vx, vy, nr, nc);
}
@ -522,7 +522,7 @@ void ggml_gemv_q2_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
}
template<int ncols_interleaved>
void ggml_gemv_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
static inline void ggml_gemv_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK_K;
const int nb = n / qk;
const int blocklen = 1;
@ -551,6 +551,10 @@ void ggml_gemv_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
for (int l = 0; l < nb; l++) {
vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0, ncols_interleaved);
// Load `dmins`.
const vfloat32m2_t dmins_d = __riscv_vfmul_vf_f32m2(
__riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d, ncols_interleaved);
// We process 4 sub-blocks at once.
const int vl = ncols_interleaved * 4;
for (int j = 0; j < QK_K / 128; j++) {
@ -584,8 +588,6 @@ void ggml_gemv_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 4] + a_ptr[l].bsums[j * 8 + 5], __riscv_vget_v_i16m4_i16m1(mins, 2), ncols_interleaved);
bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 6] + a_ptr[l].bsums[j * 8 + 7], __riscv_vget_v_i16m4_i16m1(mins, 3), ncols_interleaved);
const vfloat32m2_t dmins_d = __riscv_vfmul_vf_f32m2(
__riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d, ncols_interleaved);
sumf = __riscv_vfsub_vv_f32m2(sumf, __riscv_vfmul_vv_f32m2(dmins_d, __riscv_vfcvt_f_x_v_f32m2(bsums, ncols_interleaved), ncols_interleaved), ncols_interleaved);
// Accumulation for 2 sub-blocks.
@ -668,7 +670,7 @@ void ggml_gemv_q4_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
}
template<int ncols_interleaved>
void ggml_gemv_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
static inline void ggml_gemv_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK_K;
const int nb = n / qk;
const int blocklen = 1;
@ -697,6 +699,10 @@ void ggml_gemv_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
for (int l = 0; l < nb; l++) {
vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0, ncols_interleaved);
// Load `dmins`.
const vfloat32m2_t dmins_d = __riscv_vfmul_vf_f32m2(
__riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d, ncols_interleaved);
// We process 4 sub-blocks at once.
const int vl = ncols_interleaved * 4;
for (int j = 0; j < QK_K / 128; j++) {
@ -730,8 +736,6 @@ void ggml_gemv_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 4] + a_ptr[l].bsums[j * 8 + 5], __riscv_vget_v_i16m4_i16m1(mins, 2), ncols_interleaved);
bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 6] + a_ptr[l].bsums[j * 8 + 7], __riscv_vget_v_i16m4_i16m1(mins, 3), ncols_interleaved);
const vfloat32m2_t dmins_d = __riscv_vfmul_vf_f32m2(
__riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d, ncols_interleaved);
sumf = __riscv_vfsub_vv_f32m2(sumf, __riscv_vfmul_vv_f32m2(dmins_d, __riscv_vfcvt_f_x_v_f32m2(bsums, ncols_interleaved), ncols_interleaved), ncols_interleaved);
// Accumulation for 2 sub-blocks.
@ -892,6 +896,78 @@ void ggml_gemv_iq4_nl_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
// Dispatch wrapper: iq4_nl GEMV with 64 interleaved columns (delegates to the
// shared templated kernel).
void ggml_gemv_iq4_nl_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    ggml_gemv_iq4_nl_Mx1_q8_0<64>(n, s, bs, vx, vy, nr, nc);
}
// RVV GEMV kernel: one q8_0 activation row (vy) times an mxfp4 weight panel (vx)
// repacked with `ncols_interleaved` columns interleaved per block.
// Nibbles are expanded through a 16-entry vrgather LUT (kvalues_mxfp4); the
// per-column E8M0 block scale is combined with the activation's fp16 scale
// before the float accumulation. Result: `nc` floats written to `s`.
template<int ncols_interleaved>
static inline void ggml_gemv_mxfp4_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK8_0;
    const int nb = n / qk;
    const int blocklen = 1;

    assert(n % qk == 0);
    assert(nc % ncols_interleaved == 0);

    UNUSED(s);
    UNUSED(bs);
    UNUSED(vx);
    UNUSED(vy);
    UNUSED(nr);
    UNUSED(nc);
    UNUSED(nb);
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);

    const vint8m1_t values = __riscv_vle8_v_i8m1(kvalues_mxfp4, 16);
    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
    for (int x = 0; x < nc / ncols_interleaved; x++) {
        const block_mxfp4x<ncols_interleaved> * b_ptr = (const block_mxfp4x<ncols_interleaved> *) vx + (x * nb);
        // 1xM Accumulator
        vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved);
        for (int l = 0; l < nb; l++) {
            // 1xM integer accumulators.
            // Fix: use an integer zero — the original passed 0.0f, a float literal
            // implicitly converted to int16_t by the intrinsic's scalar operand.
            vint16m1_t sumi_0_lo_16 = __riscv_vmv_v_x_i16m1(0, ncols_interleaved);
            vint16m1_t sumi_0_hi_16 = __riscv_vmv_v_x_i16m1(0, ncols_interleaved);
            // Accumulation loop: each packed byte holds the low nibble for
            // activation qs[i] and the high nibble for qs[16 + i].
            for (int i = 0; i < QK_MXFP4 / 2; i++) {
                // Load `b_ptr` and expand both nibbles through the LUT.
                const vuint8m1_t b_0_packed = __riscv_vle8_v_u8m1((const uint8_t *)&b_ptr[l].qs[i * ncols_interleaved], ncols_interleaved);
                const vint8m1_t b_0_lo = __riscv_vrgather_vv_i8m1(values, __riscv_vand_vx_u8m1(b_0_packed, 0xf, ncols_interleaved), ncols_interleaved);
                const vint8m1_t b_0_hi = __riscv_vrgather_vv_i8m1(values, __riscv_vsrl_vx_u8m1(b_0_packed, 4, ncols_interleaved), ncols_interleaved);
                sumi_0_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_0_lo_16, a_ptr[l].qs[i], __riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_lo), ncols_interleaved);
                sumi_0_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_0_hi_16, a_ptr[l].qs[16 + i], __riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_hi), ncols_interleaved);
            }
            // Widen to i32 for the final integer add to avoid i16 overflow.
            const vint32m2_t sumi = __riscv_vwadd_vv_i32m2(sumi_0_lo_16, sumi_0_hi_16, ncols_interleaved);
            // Decode the per-column E8M0 block exponents to float scales.
            float b_scales[ncols_interleaved];
            for (int i = 0; i < ncols_interleaved; i++) {
                b_scales[i] = GGML_E8M0_TO_FP32_HALF(b_ptr[l].e[i]);
            }
            const vfloat32m2_t b_e = __riscv_vle32_v_f32m2((const float *)&b_scales[0], ncols_interleaved);
            const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_e, GGML_FP16_TO_FP32(a_ptr[l].d), ncols_interleaved);
            sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, ncols_interleaved), d_0, ncols_interleaved);
        }
        __riscv_vse32_v_f32m2(s + x * ncols_interleaved, sumf, ncols_interleaved);
    }
}
// Dispatch wrappers: instantiate the RVV mxfp4 GEMV kernel for each supported
// column-interleave width (8/16/32/64).
void ggml_gemv_mxfp4_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    ggml_gemv_mxfp4_Mx1_q8_0<8>(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemv_mxfp4_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    ggml_gemv_mxfp4_Mx1_q8_0<16>(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemv_mxfp4_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    ggml_gemv_mxfp4_Mx1_q8_0<32>(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemv_mxfp4_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    ggml_gemv_mxfp4_Mx1_q8_0<64>(n, s, bs, vx, vy, nr, nc);
}
#endif
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@ -1122,7 +1198,7 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
#if defined __riscv_zvfh
template<int ncols_interleaved>
void ggml_gemm_q4_0_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
static inline void ggml_gemm_q4_0_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK8_0;
const int nb = n / qk;
const int blocklen = 1;
@ -1221,7 +1297,7 @@ void ggml_gemm_q4_0_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v
}
template<int ncols_interleaved>
void ggml_gemm_q8_0_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
static inline void ggml_gemm_q8_0_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK8_0;
const int nb = n / qk;
const int blocklen = 1;
@ -1303,7 +1379,7 @@ void ggml_gemm_q8_0_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v
}
template<int ncols_interleaved>
static void ggml_gemm_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
static inline void ggml_gemm_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
assert(n % QK_K == 0);
const int num_k_blocks = n / QK_K;
const int N_ROWS_TILE = 4;
@ -1652,6 +1728,9 @@ void ggml_gemm_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved);
vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved);
// Load `dmins`.
const vfloat32m2_t dmins = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved);
// We process 4 sub-blocks at once.
const int vl = ncols_interleaved * 4;
for (int j = 0; j < QK_K / 128; j++) {
@ -1733,14 +1812,10 @@ void ggml_gemm_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
a_ptr[l].bsums[j * 32 + 24 + 3] + a_ptr[l].bsums[j * 32 + 24 + 7],
__riscv_vget_v_i16m4_i16m1(mins, 3), ncols_interleaved);
const vfloat32m2_t dmins_d_0 = __riscv_vfmul_vf_f32m2(
__riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[0], ncols_interleaved);
const vfloat32m2_t dmins_d_1 = __riscv_vfmul_vf_f32m2(
__riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[1], ncols_interleaved);
const vfloat32m2_t dmins_d_2 = __riscv_vfmul_vf_f32m2(
__riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[2], ncols_interleaved);
const vfloat32m2_t dmins_d_3 = __riscv_vfmul_vf_f32m2(
__riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[3], ncols_interleaved);
const vfloat32m2_t dmins_d_0 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[0], ncols_interleaved);
const vfloat32m2_t dmins_d_1 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[1], ncols_interleaved);
const vfloat32m2_t dmins_d_2 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[2], ncols_interleaved);
const vfloat32m2_t dmins_d_3 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[3], ncols_interleaved);
sumf_0 = __riscv_vfsub_vv_f32m2(sumf_0, __riscv_vfmul_vv_f32m2(dmins_d_0, __riscv_vfcvt_f_x_v_f32m2(bsums_0, ncols_interleaved), ncols_interleaved), ncols_interleaved);
sumf_1 = __riscv_vfsub_vv_f32m2(sumf_1, __riscv_vfmul_vv_f32m2(dmins_d_1, __riscv_vfcvt_f_x_v_f32m2(bsums_1, ncols_interleaved), ncols_interleaved), ncols_interleaved);
@ -1900,7 +1975,7 @@ void ggml_gemm_q4_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
}
template<int ncols_interleaved>
void ggml_gemm_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
static inline void ggml_gemm_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK_K;
const int nb = n / qk;
const int blocklen = 1;
@ -1936,6 +2011,9 @@ void ggml_gemm_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved);
vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved);
// Load `dmins`.
const vfloat32m2_t dmins = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved);
// We process 4 sub-blocks at once.
const int vl = ncols_interleaved * 4;
for (int j = 0; j < QK_K / 128; j++) {
@ -2017,14 +2095,10 @@ void ggml_gemm_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
a_ptr[l].bsums[j * 32 + 24 + 3] + a_ptr[l].bsums[j * 32 + 24 + 7],
__riscv_vget_v_i16m4_i16m1(mins, 3), ncols_interleaved);
const vfloat32m2_t dmins_d_0 = __riscv_vfmul_vf_f32m2(
__riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[0], ncols_interleaved);
const vfloat32m2_t dmins_d_1 = __riscv_vfmul_vf_f32m2(
__riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[1], ncols_interleaved);
const vfloat32m2_t dmins_d_2 = __riscv_vfmul_vf_f32m2(
__riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[2], ncols_interleaved);
const vfloat32m2_t dmins_d_3 = __riscv_vfmul_vf_f32m2(
__riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[3], ncols_interleaved);
const vfloat32m2_t dmins_d_0 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[0], ncols_interleaved);
const vfloat32m2_t dmins_d_1 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[1], ncols_interleaved);
const vfloat32m2_t dmins_d_2 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[2], ncols_interleaved);
const vfloat32m2_t dmins_d_3 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[3], ncols_interleaved);
sumf_0 = __riscv_vfsub_vv_f32m2(sumf_0, __riscv_vfmul_vv_f32m2(dmins_d_0, __riscv_vfcvt_f_x_v_f32m2(bsums_0, ncols_interleaved), ncols_interleaved), ncols_interleaved);
sumf_1 = __riscv_vfsub_vv_f32m2(sumf_1, __riscv_vfmul_vv_f32m2(dmins_d_1, __riscv_vfcvt_f_x_v_f32m2(bsums_1, ncols_interleaved), ncols_interleaved), ncols_interleaved);
@ -2185,32 +2259,16 @@ void ggml_gemm_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
}
// Dispatch wrappers: pick the RVV q5_K GEMM kernel when the zvfh (vector fp16)
// extension is available at compile time, otherwise fall back to the scalar
// generic implementation. One wrapper per interleave width (8/16/32/64).
void ggml_gemm_q5_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
#if defined __riscv_zvfh
    ggml_gemm_q5_K_Mx1_q8_K<8>(n, s, bs, vx, vy, nr, nc);
#else
    ggml_gemm_q5_K_8x1_q8_K_generic(n, s, bs, vx, vy, nr, nc);
#endif
}
void ggml_gemm_q5_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
#if defined __riscv_zvfh
    ggml_gemm_q5_K_Mx1_q8_K<16>(n, s, bs, vx, vy, nr, nc);
#else
    ggml_gemm_q5_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc);
#endif
}
void ggml_gemm_q5_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
#if defined __riscv_zvfh
    ggml_gemm_q5_K_Mx1_q8_K<32>(n, s, bs, vx, vy, nr, nc);
#else
    ggml_gemm_q5_K_32x1_q8_K_generic(n, s, bs, vx, vy, nr, nc);
#endif
}
void ggml_gemm_q5_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
#if defined __riscv_zvfh
    ggml_gemm_q5_K_Mx1_q8_K<64>(n, s, bs, vx, vy, nr, nc);
#else
    ggml_gemm_q5_K_64x1_q8_K_generic(n, s, bs, vx, vy, nr, nc);
#endif
}
template<int ncols_interleaved>
@ -2233,7 +2291,6 @@ void ggml_gemm_iq4_nl_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
UNUSED(ncols_interleaved);
UNUSED(blocklen);
#if defined __riscv_v_intrinsic
const vint8m1_t values = __riscv_vle8_v_i8m1(kvalues_iq4nl, 16);
for (int y = 0; y < nr / 4; y++) {
@ -2260,8 +2317,6 @@ void ggml_gemm_iq4_nl_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
const vuint8m1_t b_0_packed = __riscv_vle8_v_u8m1((const uint8_t *)&b_ptr[l].qs[i * ncols_interleaved], ncols_interleaved);
const vint8m1_t b_0_lo = __riscv_vrgather_vv_i8m1(values, __riscv_vand_vx_u8m1(b_0_packed, 0xf, ncols_interleaved), ncols_interleaved);
const vint8m1_t b_0_hi = __riscv_vrgather_vv_i8m1(values, __riscv_vsrl_vx_u8m1(b_0_packed, 4, ncols_interleaved), ncols_interleaved);
// const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16);
// const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16);
const vint16m1_t sumi_0_lo = __riscv_vwmul_vx_i16m1(__riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_lo), a_ptr[l].qs[i * 4], ncols_interleaved);
const vint16m1_t sumi_1_lo = __riscv_vwmul_vx_i16m1(__riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_lo), a_ptr[l].qs[i * 4 + 1], ncols_interleaved);
@ -2297,9 +2352,6 @@ void ggml_gemm_iq4_nl_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
__riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * ncols_interleaved, sumf_3, ncols_interleaved);
}
}
return;
#endif
ggml_gemm_iq4_nl_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemm_iq4_nl_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@ -2314,4 +2366,110 @@ void ggml_gemm_iq4_nl_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
// Dispatch wrapper: iq4_nl GEMM with 64 interleaved columns (delegates to the
// shared templated kernel).
void ggml_gemm_iq4_nl_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    ggml_gemm_iq4_nl_Mx1_q8_0<64>(n, s, bs, vx, vy, nr, nc);
}
// RVV GEMM kernel: a 4-row tile of q8_0 activations (block_q8_0x4, vy) times an
// mxfp4 weight panel (vx) with `ncols_interleaved` columns interleaved per block.
// Same nibble-LUT expansion as the GEMV kernel, but each expanded weight vector
// is reused across the 4 activation rows. Writes a 4 x ncols_interleaved tile
// of `s` per (y, x) iteration; `bs` is the row stride of `s` in floats.
template<int ncols_interleaved>
static inline void ggml_gemm_mxfp4_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK8_0;
    const int nb = n / qk;
    const int blocklen = 1;

    assert(n % qk == 0);
    assert(nr % 4 == 0);
    assert(nc % ncols_interleaved == 0);

    UNUSED(s);
    UNUSED(bs);
    UNUSED(vx);
    UNUSED(vy);
    UNUSED(nr);
    UNUSED(nc);
    UNUSED(nb);
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);

    const vint8m1_t values = __riscv_vle8_v_i8m1(kvalues_mxfp4, 16);
    for (int y = 0; y < nr / 4; y++) {
        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
        for (int x = 0; x < nc / ncols_interleaved; x++) {
            const block_mxfp4x<ncols_interleaved> * b_ptr = (const block_mxfp4x<ncols_interleaved> *) vx + (x * nb);
            // 4xM Accumulators
            vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved);
            vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved);
            vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved);
            vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved);
            for (int l = 0; l < nb; l++) {
                // 4xM integer accumulators.
                // Fix: use an integer zero — the original passed 0.0f, a float
                // literal implicitly converted to int16_t by the intrinsic's
                // scalar operand.
                vint16m1_t sumi_0_lo_16 = __riscv_vmv_v_x_i16m1(0, ncols_interleaved);
                vint16m1_t sumi_1_lo_16 = __riscv_vmv_v_x_i16m1(0, ncols_interleaved);
                vint16m1_t sumi_2_lo_16 = __riscv_vmv_v_x_i16m1(0, ncols_interleaved);
                vint16m1_t sumi_3_lo_16 = __riscv_vmv_v_x_i16m1(0, ncols_interleaved);
                vint16m1_t sumi_0_hi_16 = __riscv_vmv_v_x_i16m1(0, ncols_interleaved);
                vint16m1_t sumi_1_hi_16 = __riscv_vmv_v_x_i16m1(0, ncols_interleaved);
                vint16m1_t sumi_2_hi_16 = __riscv_vmv_v_x_i16m1(0, ncols_interleaved);
                vint16m1_t sumi_3_hi_16 = __riscv_vmv_v_x_i16m1(0, ncols_interleaved);
                // Accumulation loop: activations are interleaved 4 rows at a
                // time (qs[i*4 + m]); the high-nibble partners live 64 bytes in.
                for (int i = 0; i < QK_MXFP4 / 2; i++) {
                    // Load `b_ptr` and expand both nibbles through the LUT.
                    const vuint8m1_t b_0_packed = __riscv_vle8_v_u8m1((const uint8_t *)&b_ptr[l].qs[i * ncols_interleaved], ncols_interleaved);
                    const vint8m1_t b_0_lo = __riscv_vrgather_vv_i8m1(values, __riscv_vand_vx_u8m1(b_0_packed, 0xf, ncols_interleaved), ncols_interleaved);
                    const vint8m1_t b_0_hi = __riscv_vrgather_vv_i8m1(values, __riscv_vsrl_vx_u8m1(b_0_packed, 4, ncols_interleaved), ncols_interleaved);
                    sumi_0_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_0_lo_16, a_ptr[l].qs[i * 4], __riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_lo), ncols_interleaved);
                    sumi_1_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_1_lo_16, a_ptr[l].qs[i * 4 + 1], __riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_lo), ncols_interleaved);
                    sumi_2_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_2_lo_16, a_ptr[l].qs[i * 4 + 2], __riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_lo), ncols_interleaved);
                    sumi_3_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_3_lo_16, a_ptr[l].qs[i * 4 + 3], __riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_lo), ncols_interleaved);
                    sumi_0_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_0_hi_16, a_ptr[l].qs[64 + i * 4], __riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_hi), ncols_interleaved);
                    sumi_1_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_1_hi_16, a_ptr[l].qs[64 + i * 4 + 1], __riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_hi), ncols_interleaved);
                    sumi_2_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_2_hi_16, a_ptr[l].qs[64 + i * 4 + 2], __riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_hi), ncols_interleaved);
                    sumi_3_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_3_hi_16, a_ptr[l].qs[64 + i * 4 + 3], __riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_hi), ncols_interleaved);
                }
                // Do the final accumulation in i32 to prevent overflow.
                const vint32m2_t sumi_0 = __riscv_vwadd_vv_i32m2(sumi_0_lo_16, sumi_0_hi_16, ncols_interleaved);
                const vint32m2_t sumi_1 = __riscv_vwadd_vv_i32m2(sumi_1_lo_16, sumi_1_hi_16, ncols_interleaved);
                const vint32m2_t sumi_2 = __riscv_vwadd_vv_i32m2(sumi_2_lo_16, sumi_2_hi_16, ncols_interleaved);
                const vint32m2_t sumi_3 = __riscv_vwadd_vv_i32m2(sumi_3_lo_16, sumi_3_hi_16, ncols_interleaved);
                // Decode the per-column E8M0 block exponents to float scales.
                float b_scales[ncols_interleaved];
                for (int i = 0; i < ncols_interleaved; i++) {
                    b_scales[i] = GGML_E8M0_TO_FP32_HALF(b_ptr[l].e[i]);
                }
                const vfloat32m2_t b_e = __riscv_vle32_v_f32m2((const float *)&b_scales[0], ncols_interleaved);
                const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_e, GGML_FP16_TO_FP32(a_ptr[l].d[0]), ncols_interleaved);
                const vfloat32m2_t d_1 = __riscv_vfmul_vf_f32m2(b_e, GGML_FP16_TO_FP32(a_ptr[l].d[1]), ncols_interleaved);
                const vfloat32m2_t d_2 = __riscv_vfmul_vf_f32m2(b_e, GGML_FP16_TO_FP32(a_ptr[l].d[2]), ncols_interleaved);
                const vfloat32m2_t d_3 = __riscv_vfmul_vf_f32m2(b_e, GGML_FP16_TO_FP32(a_ptr[l].d[3]), ncols_interleaved);
                sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, ncols_interleaved), d_0, ncols_interleaved);
                sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, ncols_interleaved), d_1, ncols_interleaved);
                sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, ncols_interleaved), d_2, ncols_interleaved);
                sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, ncols_interleaved), d_3, ncols_interleaved);
            }
            __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * ncols_interleaved, sumf_0, ncols_interleaved);
            __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * ncols_interleaved, sumf_1, ncols_interleaved);
            __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * ncols_interleaved, sumf_2, ncols_interleaved);
            __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * ncols_interleaved, sumf_3, ncols_interleaved);
        }
    }
}
// Dispatch wrappers: instantiate the RVV mxfp4 GEMM kernel for each supported
// column-interleave width (8/16/32/64).
void ggml_gemm_mxfp4_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    ggml_gemm_mxfp4_Mx1_q8_0<8>(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemm_mxfp4_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    ggml_gemm_mxfp4_Mx1_q8_0<16>(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemm_mxfp4_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    ggml_gemm_mxfp4_Mx1_q8_0<32>(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemm_mxfp4_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    ggml_gemm_mxfp4_Mx1_q8_0<64>(n, s, bs, vx, vy, nr, nc);
}
#endif

View File

@ -1,4 +1,3 @@
#include "ggml.h"
#define GGML_COMMON_IMPL_CPP
#define GGML_COMMON_DECL_CPP
#include "ggml-common.h"
@ -725,6 +724,44 @@ static inline void ggml_gemv_iq4_nl_Mx1_q8_0_generic(int n, float * GGML_RESTRIC
for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
}
}
// Scalar reference GEMV: one q8_0 activation row (vy) times an mxfp4 weight
// panel (vx) repacked with `ncols_interleaved` columns per block. Each packed
// byte contributes its low nibble against activation qs[k] and its high nibble
// against qs[k + qk/2]; the per-column E8M0 scale and the row's fp16 scale are
// applied per block. Writes `nc` floats to `s`.
template<int ncols_interleaved>
static inline void ggml_gemv_mxfp4_Mx1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK8_0;
    const int nb = n / qk;
    const int blocklen = 1;

    assert(nr == 1);
    assert(n % qk == 0);
    assert(nc % ncols_interleaved == 0);
    UNUSED(bs);
    UNUSED(nr);

    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;

    for (int x = 0; x < nc / ncols_interleaved; x++) {
        const block_mxfp4x<ncols_interleaved> * b_ptr = (const block_mxfp4x<ncols_interleaved> *) vx + (x * nb);

        // Per-column accumulators for this group of interleaved columns.
        float acc[ncols_interleaved] = { 0.0f };

        for (int l = 0; l < nb; l++) {
            // Activation scale is constant over the block — decode it once.
            const float a_d = GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
                for (int j = 0; j < ncols_interleaved; j++) {
                    int dot = 0;
                    for (int i = 0; i < blocklen; ++i) {
                        const uint8_t packed = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];
                        dot += kvalues_mxfp4[packed & 0x0F] * a_ptr[l].qs[k * blocklen + i];
                        dot += kvalues_mxfp4[packed >> 4]   * a_ptr[l].qs[k * blocklen + i + qk / 2];
                    }
                    acc[j] += dot * GGML_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * a_d;
                }
            }
        }

        for (int j = 0; j < ncols_interleaved; j++) {
            s[x * ncols_interleaved + j] = acc[j];
        }
    }
}
#endif
#if defined __riscv_zvfh
@ -1157,8 +1194,448 @@ static inline void ggml_gemm_iq4_nl_Mx1_q8_0_generic(int n, float * GGML_RESTRIC
}
}
}
// Scalar reference GEMM: a 4-row tile of q8_0 activations (block_q8_0x4, vy)
// times an mxfp4 weight panel (vx) with `ncols_interleaved` columns per block.
// Activations are interleaved 4 rows at a time inside each block_q8_0x4, so row
// m's k-th value sits at qs[k*4 + m] and its high-nibble partner (qk/2)*4 bytes
// later. Writes a 4 x ncols_interleaved tile of `s` per (y, x); `bs` is the row
// stride of `s` in floats.
template<int ncols_interleaved>
static inline void ggml_gemm_mxfp4_Mx1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK8_0;
    const int nb = n / qk;
    const int blocklen = 1;

    assert(n % qk == 0);
    assert(nr % 4 == 0);
    assert(nc % ncols_interleaved == 0);

    // 4 activation rows x ncols_interleaved columns of float accumulators.
    float sumf[4][ncols_interleaved];
    int sumi;

    for (int y = 0; y < nr / 4; y++) {
        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
        for (int x = 0; x < nc / ncols_interleaved; x++) {
            const block_mxfp4x<ncols_interleaved> * b_ptr = (const block_mxfp4x<ncols_interleaved> *) vx + (x * nb);
            for (int m = 0; m < 4; m++) {
                for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
            }
            for (int l = 0; l < nb; l++) {
                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
                    for (int m = 0; m < 4; m++) {
                        for (int j = 0; j < ncols_interleaved; j++) {
                            sumi = 0;
                            for (int i = 0; i < blocklen; ++i) {
                                // Expand both nibbles of the packed weight byte via the mxfp4 LUT.
                                const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
                                const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
                                // Low nibble pairs with the k-th activation of row m; the high
                                // nibble pairs with the value (qk/2)*4 bytes later (rows stay
                                // interleaved by 4).
                                sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
                                         (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + (qk / 2) * 4]));
                            }
                            // Scale by the column's E8M0 block exponent and row m's fp16 scale.
                            sumf[m][j] += sumi * GGML_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
                        }
                    }
                }
            }
            for (int m = 0; m < 4; m++) {
                for (int j = 0; j < ncols_interleaved; j++)
                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
            }
        }
    }
}
#endif
// Scalar reference GEMV for q6_K weights repacked as block_q6_Kx8 against a
// q8_K activation row. M is the interleave block length, N the number of
// interleaved columns. For each sub-block the 4-bit low quants (ql) are
// combined with 2-bit high quants (qh) into signed 6-bit values (offset -32),
// scaled by the per-16-element int8 scale and the per-column fp16 super-scale.
// NOTE(review): the index arithmetic assumes the block_q6_Kx8 repack layout
// (128-element halves, qh shared across four 32-element quarters) — confirm
// against the repack routine when modifying.
template <int M, int N>
static void ggml_gemv_q6_K_NxM_q8_K_generic_impl(int n,
                                                 float * GGML_RESTRICT s,
                                                 size_t bs,
                                                 const void * GGML_RESTRICT vx,
                                                 const void * GGML_RESTRICT vy,
                                                 int nr,
                                                 int nc) {
    constexpr int blocklen          = M;
    constexpr int ncols_interleaved = N;

    const int qk              = QK_K;
    const int nb              = n / qk;
    const int blocks_per_half = 64 / blocklen;

    assert(n % qk == 0);
    assert(nc % ncols_interleaved == 0);
    UNUSED(bs);
    UNUSED(nr);

    // Fix: size the accumulator by the template parameter instead of a
    // hard-coded 8 — the loops below index sumf[j] for j < ncols_interleaved,
    // which overflowed the stack buffer for any instantiation with N > 8.
    float sumf[ncols_interleaved];

    const block_q8_K * a_ptr = (const block_q8_K *) vy;
    for (int x = 0; x < nc / ncols_interleaved; x++) {
        const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb);

        for (int j = 0; j < ncols_interleaved; j++) {
            sumf[j] = 0.0f;
        }
        for (int l = 0; l < nb; l++) {
            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
                // base_l / base_h: positions of the low- and high-nibble groups
                // inside the 256-element super-block (halves are 128 apart;
                // the nibble partner is 64 elements ahead).
                const int base_l      = (k / blocks_per_half) * 128 + (k % blocks_per_half) * blocklen;
                const int base_h      = base_l + 64;
                const int scale_idx_l = base_l / 16;          // one int8 scale per 16 elements
                const int scale_idx_h = base_h / 16;
                const int qh_shift_l  = ((base_l % 128) / 32) * 2;  // 2-bit field within the qh byte
                const int qh_shift_h  = ((base_h % 128) / 32) * 2;
                const int qh_half_l   = (base_l / 128) * 32;  // which 32-byte qh half
                const int qh_half_h   = (base_h / 128) * 32;
                for (int j = 0; j < ncols_interleaved; j++) {
                    const int8_t scale_l = b_ptr[l].scales[scale_idx_l * ncols_interleaved + j];
                    const int8_t scale_h = b_ptr[l].scales[scale_idx_h * ncols_interleaved + j];

                    int sumi_l = 0;
                    int sumi_h = 0;
                    for (int i = 0; i < blocklen; i++) {
                        // Low 4 bits and high 4 bits of the packed ql byte.
                        const int ql_pos = k * ncols_interleaved * blocklen + j * blocklen + i;
                        const int l_4    = b_ptr[l].ql[ql_pos] & 0xF;
                        const int hi_4   = (b_ptr[l].ql[ql_pos] >> 4) & 0xF;

                        // Locate the 2-bit high quants in the interleaved qh array.
                        const int qh_idx_l    = qh_half_l + ((base_l + i) % 32);
                        const int qh_chunk_l  = qh_idx_l / blocklen;
                        const int qh_pos_l    = qh_idx_l % blocklen;
                        const int qh_offset_l = qh_chunk_l * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_l;
                        const int hi_2_l      = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3;

                        const int qh_idx_h    = qh_half_h + ((base_h + i) % 32);
                        const int qh_chunk_h  = qh_idx_h / blocklen;
                        const int qh_pos_h    = qh_idx_h % blocklen;
                        const int qh_offset_h = qh_chunk_h * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_h;
                        const int hi_2_h      = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3;

                        // Reassemble the signed 6-bit quants (bias -32).
                        const int q_l = ((hi_2_l << 4) | l_4) - 32;
                        const int q_h = ((hi_2_h << 4) | hi_4) - 32;

                        const int8_t a_l = a_ptr[l].qs[base_l + i];
                        const int8_t a_h = a_ptr[l].qs[base_h + i];

                        sumi_l += q_l * a_l;
                        sumi_h += q_h * a_h;
                    }
                    sumf[j] +=
                        (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
                }
            }
        }
        for (int j = 0; j < ncols_interleaved; j++) {
            s[x * ncols_interleaved + j] = sumf[j];
        }
    }
}
// Generic reference GEMM: q8_K activations repacked in 4-row tiles (block_q8_Kx4)
// times q6_K weights repacked N columns wide (block_q6_Kx8), producing a
// 4 x ncols_interleaved output tile per (y, x) pair.
// M = blocklen (weights handled per inner step), N = ncols_interleaved.
template <int M, int N>
static void ggml_gemm_q6_K_NxM_q8_K_generic_impl(int n,
                                                 float * GGML_RESTRICT s,
                                                 size_t bs,
                                                 const void * GGML_RESTRICT vx,
                                                 const void * GGML_RESTRICT vy,
                                                 int nr,
                                                 int nc) {
    constexpr int blocklen = M;
    constexpr int ncols_interleaved = N;
    const int qk = QK_K;
    const int nb = n / qk;                     // super-blocks along the reduction dimension
    const int blocks_per_half = 64 / blocklen; // inner steps covering one 64-weight group
    const int q8_half_stride = 512;            // 128 q8 values * 4 rows: byte stride between 128-weight halves
    const int q8_low_high_step = 256;          // 64 q8 values * 4 rows: distance from low to high 64-weight group
    assert(n % qk == 0);
    assert(nr % 4 == 0);
    assert(nc % ncols_interleaved == 0);
    UNUSED(bs);
    float sumf[4][8];
    for (int y = 0; y < nr / 4; y++) {
        const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
        for (int x = 0; x < nc / ncols_interleaved; x++) {
            const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb);
            for (int m = 0; m < 4; m++) {
                for (int j = 0; j < ncols_interleaved; j++) {
                    sumf[m][j] = 0.0f;
                }
            }
            for (int l = 0; l < nb; l++) {
                // Each step k covers blocklen weights from the low 64-weight group of a
                // 128-weight half plus the matching blocklen weights from the high group
                // (both share the same ql bytes: low nibble vs high nibble).
                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
                    const int base_l = (k / blocks_per_half) * 128 + (k % blocks_per_half) * blocklen;
                    const int base_h = base_l + 64;
                    // One int8 scale per 16 weights, stored interleaved per column.
                    const int scale_idx_l = base_l / 16;
                    const int scale_idx_h = base_h / 16;
                    // qh bit shift cycles 0,2,4,6 over the four 32-weight groups of a 128-weight half.
                    const int qh_shift_l = ((base_l % 128) / 32) * 2;
                    const int qh_shift_h = ((base_h % 128) / 32) * 2;
                    // qh_half selects the 32-byte qh half (0 or 32) for this 128-weight half.
                    const int qh_half_l = (base_l / 128) * 32;
                    const int qh_half_h = (base_h / 128) * 32;
                    const int q8_base = (k / blocks_per_half) * q8_half_stride + (k % blocks_per_half) * (blocklen * 4);
                    for (int m = 0; m < 4; m++) {
                        for (int j = 0; j < ncols_interleaved; j++) {
                            const int8_t scale_l = b_ptr[l].scales[scale_idx_l * ncols_interleaved + j];
                            const int8_t scale_h = b_ptr[l].scales[scale_idx_h * ncols_interleaved + j];
                            int sumi_l = 0;
                            int sumi_h = 0;
                            for (int i = 0; i < blocklen; i++) {
                                // ql byte: low nibble -> low-group weight, high nibble -> high-group weight.
                                const int ql_pos = k * ncols_interleaved * blocklen + j * blocklen + i;
                                const int l_4 = b_ptr[l].ql[ql_pos] & 0xF;
                                const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF;
                                // Locate the column-interleaved qh byte holding the top 2 bits (low group).
                                const int qh_idx_l = qh_half_l + ((base_l + i) % 32);
                                const int qh_chunk_l = qh_idx_l / blocklen;
                                const int qh_pos_l = qh_idx_l % blocklen;
                                const int qh_offset_l =
                                    qh_chunk_l * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_l;
                                const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3;
                                // Same lookup for the high group.
                                const int qh_idx_h = qh_half_h + ((base_h + i) % 32);
                                const int qh_chunk_h = qh_idx_h / blocklen;
                                const int qh_pos_h = qh_idx_h % blocklen;
                                const int qh_offset_h =
                                    qh_chunk_h * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_h;
                                const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3;
                                // Reassemble the signed 6-bit weights: ((hi2 << 4) | lo4) - 32.
                                const int q_l = ((hi_2_l << 4) | l_4) - 32;
                                const int q_h = ((hi_2_h << 4) | hi_4) - 32;
                                const int8_t q8_l = a_ptr[l].qs[q8_base + m * blocklen + i];
                                const int8_t q8_h = a_ptr[l].qs[q8_base + m * blocklen + i + q8_low_high_step];
                                sumi_l += q_l * q8_l;
                                sumi_h += q_h * q8_h;
                            }
                            // Apply sub-block scales, per-column fp16 weight scale, per-row activation scale.
                            sumf[m][j] += (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) *
                                          a_ptr[l].d[m];
                        }
                    }
                }
            }
            for (int m = 0; m < 4; m++) {
                for (int j = 0; j < ncols_interleaved; j++) {
                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
                }
            }
        }
    }
}
// Generic reference GEMV: one q8_K activation row times q5_K weights repacked
// N columns wide (block_q5_Kx8). K-quant packed scales/mins are unpacked with
// the standard 6-bit masks; q5 values are rebuilt from 4 low bits (qs) plus
// 1 high bit (qh). M = blocklen, N = ncols_interleaved.
template <int M, int N>
static void ggml_gemv_q5_K_NxM_q8_K_generic_impl(int n,
                                                 float * GGML_RESTRICT s,
                                                 size_t bs,
                                                 const void * GGML_RESTRICT vx,
                                                 const void * GGML_RESTRICT vy,
                                                 int nr,
                                                 int nc) {
    constexpr int blocklen = M;
    constexpr int ncols_interleaved = N;
    const int qk = QK_K;
    const int nb = n / qk;
    // Masks used to split each packed 12-byte K-quant scale block into 6-bit scales and mins.
    static const uint32_t kmask1 = 0x3f3f3f3f;
    static const uint32_t kmask2 = 0x0f0f0f0f;
    static const uint32_t kmask3 = 0x03030303;
    assert(n % qk == 0);
    assert(nc % ncols_interleaved == 0);
    UNUSED(bs);
    UNUSED(nr);
    float sumf[ncols_interleaved];
    float sum_minf[ncols_interleaved]; // accumulated mins correction, subtracted at the end
    uint32_t utmp[32];                 // unpacked scales+mins: 16 bytes per scale block (scales 0..7, mins 8..15)
    int sumi1;
    int sumi2;
    int sumi;
    const block_q8_K * a_ptr = (const block_q8_K *) vy;
    for (int x = 0; x < nc / ncols_interleaved; x++) {
        const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);
        for (int j = 0; j < ncols_interleaved; j++) {
            sumf[j] = 0.0;
            sum_minf[j] = 0.0;
        }
        for (int l = 0; l < nb; l++) {
            // Unpack the 8 packed scale blocks: afterwards utmp[sb*4 + 0..1] hold the
            // 8 scale bytes and utmp[sb*4 + 2..3] hold the 8 min bytes of block sb.
            for (int sb = 0; sb < 8; sb++) {
                memcpy(utmp + sb * 4, b_ptr[l].scales + sb * K_SCALE_SIZE, K_SCALE_SIZE);
                utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
                const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
                utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
                utmp[sb * 4 + 2] = uaux_0;
                utmp[sb * 4 + 0] &= kmask1;
            }
            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
                constexpr int scale_stride = 32;
                // Per-column scales of the two 32-weight sub-blocks sharing this group of qs bytes
                // (low nibbles -> scales_0, high nibbles -> scales_1).
                uint8_t * scales_0 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride;
                uint8_t * scales_1 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride + 16;
                const int qh_shift = (k / (32 / blocklen)) * 2; // 2 high bits consumed per 64-weight group
                for (int j = 0; j < ncols_interleaved; j++) {
                    sumi1 = 0;
                    sumi2 = 0;
                    sumi = 0;
                    for (int i = 0; i < blocklen; ++i) {
                        const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;
                        // qh bytes are interleaved per column in blocklen-sized chunks.
                        const int qh_idx = (k * blocklen + i) % 32;
                        const int qh_chunk = qh_idx / blocklen;
                        const int qh_pos = qh_idx % blocklen;
                        const int b_qh_offset = qh_chunk * (blocklen * ncols_interleaved) + j * blocklen + qh_pos;
                        const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];
                        const uint8_t h0 = (qh_val >> qh_shift) & 1;
                        const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1;
                        // 5-bit values: low nibble + high bit (v0), high nibble + next high bit (v1).
                        const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));
                        const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));
                        // Matching q8 values; the high-nibble group lives 32 values later.
                        const int q8_offset = (k / (32 / blocklen)) * 64 + (k % (32 / blocklen)) * blocklen + i;
                        sumi1 = (v0 * a_ptr[l].qs[q8_offset]);
                        sumi2 = (v1 * a_ptr[l].qs[q8_offset + 32]);
                        sumi1 = sumi1 * scales_0[j];
                        sumi2 = sumi2 * scales_1[j];
                        sumi += sumi1 + sumi2;
                    }
                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
                }
            }
            // Mins correction: each sub-block min weighs the activation bsums of its 32 values.
            for (int sb = 0; sb < 8; sb++) {
                uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
                for (int j = 0; j < ncols_interleaved; j++) {
                    sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) *
                                   GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
                }
            }
        }
        for (int j = 0; j < ncols_interleaved; j++) {
            s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
        }
    }
}
// Generic reference GEMM: q8_K activations repacked in 4-row tiles (block_q8_Kx4)
// times q5_K weights repacked N columns wide (block_q5_Kx8). Same value
// reconstruction as the GEMV variant above, but processing 4 activation rows
// per step. M = blocklen, N = ncols_interleaved.
template <int M, int N>
static void ggml_gemm_q5_K_NxM_q8_K_generic_impl(int n,
                                                 float * GGML_RESTRICT s,
                                                 size_t bs,
                                                 const void * GGML_RESTRICT vx,
                                                 const void * GGML_RESTRICT vy,
                                                 int nr,
                                                 int nc) {
    constexpr int blocklen = M;
    constexpr int ncols_interleaved = N;
    const int qk = QK_K;
    const int nb = n / qk;
    // Masks used to split each packed 12-byte K-quant scale block into 6-bit scales and mins.
    static const uint32_t kmask1 = 0x3f3f3f3f;
    static const uint32_t kmask2 = 0x0f0f0f0f;
    static const uint32_t kmask3 = 0x03030303;
    assert(n % qk == 0);
    assert(nr % 4 == 0);
    assert(nc % ncols_interleaved == 0);
    float sumf[4][ncols_interleaved];
    float sum_minf[4][ncols_interleaved]; // mins correction per row, subtracted at the end
    uint32_t utmp[32];                    // unpacked scales+mins, 16 bytes per scale block
    int sumi1;
    int sumi2;
    int sumi;
    for (int y = 0; y < nr / 4; y++) {
        const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
        for (int x = 0; x < nc / ncols_interleaved; x++) {
            const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);
            for (int m = 0; m < 4; m++) {
                for (int j = 0; j < ncols_interleaved; j++) {
                    sumf[m][j] = 0.0;
                    sum_minf[m][j] = 0.0;
                }
            }
            for (int l = 0; l < nb; l++) {
                // Unpack scales/mins: utmp[sb*4 + 0..1] = scales, utmp[sb*4 + 2..3] = mins.
                for (int sb = 0; sb < 8; sb++) {
                    memcpy(utmp + sb * 4, b_ptr[l].scales + sb * K_SCALE_SIZE, K_SCALE_SIZE);
                    utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
                    const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
                    utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
                    utmp[sb * 4 + 2] = uaux_0;
                    utmp[sb * 4 + 0] &= kmask1;
                }
                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
                    constexpr int scale_stride = 32;
                    // Per-column scales of the low-nibble (scales_0) and high-nibble (scales_1) sub-blocks.
                    uint8_t * scales_0 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride;
                    uint8_t * scales_1 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride + 16;
                    const int qh_shift = (k / (32 / blocklen)) * 2; // 2 high bits consumed per 64-weight group
                    for (int m = 0; m < 4; m++) {
                        for (int j = 0; j < ncols_interleaved; j++) {
                            sumi1 = 0;
                            sumi2 = 0;
                            sumi = 0;
                            for (int i = 0; i < blocklen; ++i) {
                                const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;
                                // qh bytes are interleaved per column in blocklen-sized chunks.
                                const int qh_idx = (k * blocklen + i) % 32;
                                const int qh_chunk = qh_idx / blocklen;
                                const int qh_pos = qh_idx % blocklen;
                                const int b_qh_offset =
                                    qh_chunk * (blocklen * ncols_interleaved) + j * blocklen + qh_pos;
                                const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];
                                const uint8_t h0 = (qh_val >> qh_shift) & 1;
                                const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1;
                                // 5-bit values: low nibble + high bit (v0), high nibble + next high bit (v1).
                                const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));
                                const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));
                                // q8 values for row m; the high-nibble group lives 128 bytes (32 values * 4 rows) later.
                                const int q8_offset = (k / (32 / blocklen)) * 256 +
                                                      (k % (32 / blocklen)) * 4 * blocklen + m * blocklen + i;
                                sumi1 = (v0 * a_ptr[l].qs[q8_offset]);
                                sumi2 = (v1 * a_ptr[l].qs[q8_offset + 128]);
                                sumi1 = sumi1 * scales_0[j];
                                sumi2 = sumi2 * scales_1[j];
                                sumi += sumi1 + sumi2;
                            }
                            sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
                        }
                    }
                }
                for (int sb = 0; sb < 8; sb++) {
                    uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
                    for (int m = 0; m < 4; m++) {
                        // NOTE(review): index math assumes the interleaved bsums layout produced by the
                        // q8_K repack quantizer ((sb % 2) * 6 undoes the interleave) — verify against it.
                        const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
                        for (int j = 0; j < ncols_interleaved; j++) {
                            sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) *
                                              GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
                        }
                    }
                }
            }
            for (int m = 0; m < 4; m++) {
                for (int j = 0; j < ncols_interleaved; j++) {
                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
                }
            }
        }
    }
}
extern "C" {
void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@ -1520,8 +1997,100 @@ void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
ggml_gemv_q5_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc);
}
// GEMV: q6_K weights repacked 8 columns wide with blocklen 4, against q8_K activations.
void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    ggml_gemv_q6_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc);
}
// GEMV: q6_K weights repacked 8 columns wide with blocklen 8, against q8_K activations.
// Thin dispatch to the shared generic implementation. The hand-unrolled duplicate of
// that loop — which recomputed the exact same result a second time and contained a
// leftover per-element debug printf — has been dropped.
void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    ggml_gemv_q6_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@ -1855,6 +2424,20 @@ void ggml_gemv_iq4_nl_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t b
// GEMV: iq4_nl weights repacked 64 rows interleaved (blocklen 1), against q8_0 activations.
void ggml_gemv_iq4_nl_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    ggml_gemv_iq4_nl_Mx1_q8_0_generic<64>(n, s, bs, vx, vy, nr, nc);
}
// MXFP4
// GEMV wrappers: mxfp4 weights repacked M rows interleaved (M = 8/16/32/64,
// blocklen 1), against q8_0 activations.
void ggml_gemv_mxfp4_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    ggml_gemv_mxfp4_Mx1_q8_0_generic<8>(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemv_mxfp4_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    ggml_gemv_mxfp4_Mx1_q8_0_generic<16>(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemv_mxfp4_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    ggml_gemv_mxfp4_Mx1_q8_0_generic<32>(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemv_mxfp4_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    ggml_gemv_mxfp4_Mx1_q8_0_generic<64>(n, s, bs, vx, vy, nr, nc);
}
#endif
void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@ -2670,6 +3253,20 @@ void ggml_gemm_iq4_nl_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t b
// GEMM: iq4_nl weights repacked 64 rows interleaved (blocklen 1), against q8_0 activations.
void ggml_gemm_iq4_nl_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    ggml_gemm_iq4_nl_Mx1_q8_0_generic<64>(n, s, bs, vx, vy, nr, nc);
}
// MXFP4
// GEMM wrappers: mxfp4 weights repacked M rows interleaved (M = 8/16/32/64,
// blocklen 1), against q8_0 activations.
void ggml_gemm_mxfp4_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    ggml_gemm_mxfp4_Mx1_q8_0_generic<8>(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemm_mxfp4_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    ggml_gemm_mxfp4_Mx1_q8_0_generic<16>(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemm_mxfp4_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    ggml_gemm_mxfp4_Mx1_q8_0_generic<32>(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemm_mxfp4_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    ggml_gemm_mxfp4_Mx1_q8_0_generic<64>(n, s, bs, vx, vy, nr, nc);
}
#endif
} // extern "C"
@ -3769,6 +4366,59 @@ static int repack_iq4_nl_to_iq4_nl_Mx1_bl(struct ggml_tensor * t, const void * G
GGML_UNUSED(data_size);
}
// Interleave `nrows_interleaved` mxfp4 blocks into one packed block: scales are
// copied per row, then quant bytes are laid out so that output byte (off * R + r)
// is byte `off` of row `r` — i.e. rows alternate byte-by-byte.
template<int nrows_interleaved>
static block_mxfp4x<nrows_interleaved> make_block_mxfp4xMx1(block_mxfp4 * in) {
    block_mxfp4x<nrows_interleaved> out;

    for (int r = 0; r < nrows_interleaved; ++r) {
        out.e[r] = in[r].e;
    }

    constexpr int bytes_per_row = QK_MXFP4 / 2;
    int dst = 0;
    for (int off = 0; off < bytes_per_row; ++off) {
        for (int r = 0; r < nrows_interleaved; ++r) {
            out.qs[dst++] = in[r].qs[off];
        }
    }
    return out;
}
// Repack mxfp4 rows into the row-interleaved block_mxfp4x<nrows_interleaved> layout:
// each group of `nrows_interleaved` consecutive rows is merged block-by-block so that
// every output block carries one mxfp4 block from each row in the group.
// Returns 0 on success, -1 if the row count is not a multiple of nrows_interleaved.
template<int nrows_interleaved>
static int repack_mxfp4_to_mxfp4_Mx1_bl(struct ggml_tensor * t, const void * GGML_RESTRICT data, size_t data_size) {
    GGML_ASSERT(t->type == GGML_TYPE_MXFP4);
    const block_mxfp4 * src = (const block_mxfp4 *)data;
    block_mxfp4x<nrows_interleaved> * dst = ( block_mxfp4x<nrows_interleaved> *)t->data;
    block_mxfp4 dst_tmp[nrows_interleaved]; // staging area: one block from each row of the group
    int nrow = ggml_nrows(t);
    int nblocks = t->ne[0] / QK_MXFP4;      // mxfp4 blocks per row
    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_mxfp4));
    if (t->ne[1] % nrows_interleaved != 0) {
        return -1;
    }
    for (int b = 0; b < nrow; b += nrows_interleaved) {
        for (int64_t x = 0; x < nblocks; x++) {
            // Gather block x of each of the nrows_interleaved rows, then interleave.
            for (int i = 0; i < nrows_interleaved; i++) {
                dst_tmp[i] = src[x + i * nblocks];
            }
            *dst++ = make_block_mxfp4xMx1<nrows_interleaved>(dst_tmp);
        }
        src += nrows_interleaved * nblocks;
    }
    return 0;
    GGML_UNUSED(data_size); // unreachable; kept for parity with the other repack helpers
}
#endif
static block_mxfp4x4 make_block_mxfp4x4(block_mxfp4 * in, unsigned int blck_size_interleave) {
@ -4044,6 +4694,20 @@ template <> int repack<block_iq4_nl, 1, 32>(struct ggml_tensor * t, const void *
// repack() entry point: interleave iq4_nl rows 64 at a time.
template <> int repack<block_iq4_nl, 1, 64>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_iq4_nl_to_iq4_nl_Mx1_bl<64>(t, data, data_size);
}
// MXFP4
// repack() entry points: interleave mxfp4 rows 8/16/32/64 at a time.
template <> int repack<block_mxfp4, 1, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_mxfp4_to_mxfp4_Mx1_bl<8>(t, data, data_size);
}
template <> int repack<block_mxfp4, 1, 16>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_mxfp4_to_mxfp4_Mx1_bl<16>(t, data, data_size);
}
template <> int repack<block_mxfp4, 1, 32>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_mxfp4_to_mxfp4_Mx1_bl<32>(t, data, data_size);
}
template <> int repack<block_mxfp4, 1, 64>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_mxfp4_to_mxfp4_Mx1_bl<64>(t, data, data_size);
}
#endif
// gemv
@ -4205,6 +4869,20 @@ template <> void gemv<block_iq4_nl, 1, 32, GGML_TYPE_Q8_0>(int n, float * s, siz
// gemv() trait entry point for the 64-row-interleaved iq4_nl layout.
template <> void gemv<block_iq4_nl, 1, 64, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemv_iq4_nl_64x1_q8_0(n, s, bs, vx, vy, nr, nc);
}
// MXFP4
// gemv() trait entry points for the 8/16/32/64-row-interleaved mxfp4 layouts.
template <> void gemv<block_mxfp4, 1, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemv_mxfp4_8x1_q8_0(n, s, bs, vx, vy, nr, nc);
}
template <> void gemv<block_mxfp4, 1, 16, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemv_mxfp4_16x1_q8_0(n, s, bs, vx, vy, nr, nc);
}
template <> void gemv<block_mxfp4, 1, 32, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemv_mxfp4_32x1_q8_0(n, s, bs, vx, vy, nr, nc);
}
template <> void gemv<block_mxfp4, 1, 64, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemv_mxfp4_64x1_q8_0(n, s, bs, vx, vy, nr, nc);
}
#endif
// gemm
@ -4366,6 +5044,20 @@ template <> void gemm<block_iq4_nl, 1, 32, GGML_TYPE_Q8_0>(int n, float * s, siz
// gemm() trait entry point for the 64-row-interleaved iq4_nl layout.
template <> void gemm<block_iq4_nl, 1, 64, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemm_iq4_nl_64x1_q8_0(n, s, bs, vx, vy, nr, nc);
}
// MXFP4
// gemm() trait entry points for the 8/16/32/64-row-interleaved mxfp4 layouts.
template <> void gemm<block_mxfp4, 1, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemm_mxfp4_8x1_q8_0(n, s, bs, vx, vy, nr, nc);
}
template <> void gemm<block_mxfp4, 1, 16, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemm_mxfp4_16x1_q8_0(n, s, bs, vx, vy, nr, nc);
}
template <> void gemm<block_mxfp4, 1, 32, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemm_mxfp4_32x1_q8_0(n, s, bs, vx, vy, nr, nc);
}
template <> void gemm<block_mxfp4, 1, 64, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemm_mxfp4_64x1_q8_0(n, s, bs, vx, vy, nr, nc);
}
#endif
class tensor_traits_base : public ggml::cpu::tensor_traits {
@ -4816,6 +5508,12 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 1, 16, GGML_TYPE_Q8_0> iq4_nl_16x1_q8_0;
static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 1, 32, GGML_TYPE_Q8_0> iq4_nl_32x1_q8_0;
static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 1, 64, GGML_TYPE_Q8_0> iq4_nl_64x1_q8_0;
// MXFP4
static const ggml::cpu::repack::tensor_traits<block_mxfp4, 1, 8, GGML_TYPE_Q8_0> mxfp4_8x1_q8_0;
static const ggml::cpu::repack::tensor_traits<block_mxfp4, 1, 16, GGML_TYPE_Q8_0> mxfp4_16x1_q8_0;
static const ggml::cpu::repack::tensor_traits<block_mxfp4, 1, 32, GGML_TYPE_Q8_0> mxfp4_32x1_q8_0;
static const ggml::cpu::repack::tensor_traits<block_mxfp4, 1, 64, GGML_TYPE_Q8_0> mxfp4_64x1_q8_0;
#endif
if (cur->type == GGML_TYPE_Q4_0) {
@ -4899,6 +5597,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
if (cur->ne[1] % 8 == 0) {
return &q5_K_8x4_q8_K;
}
}
if (ggml_cpu_has_riscv_v()) {
#if defined __riscv_zvfh
switch (__riscv_vlenb() * 8) {
@ -4976,6 +5675,18 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
}
#endif
}
} else if (cur->type == GGML_TYPE_MXFP4) {
if (ggml_cpu_has_riscv_v()) {
#if defined __riscv_zvfh
switch (__riscv_vlenb() * 8) {
case 128: { if (cur->ne[1] % 8 == 0) { return &mxfp4_8x1_q8_0; } break; }
case 256: { if (cur->ne[1] % 16 == 0) { return &mxfp4_16x1_q8_0; } break; }
case 512: { if (cur->ne[1] % 32 == 0) { return &mxfp4_32x1_q8_0; } break; }
case 1024: { if (cur->ne[1] % 64 == 0) { return &mxfp4_64x1_q8_0; } break; }
default: { return nullptr; }
}
#endif
}
}
return nullptr;

View File

@ -100,6 +100,16 @@ static_assert(sizeof(block_q5_Kx16) == sizeof(ggml_half) * 32 + K_SCALE_SIZE * 1
static_assert(sizeof(block_q5_Kx32) == sizeof(ggml_half) * 64 + K_SCALE_SIZE * 32 + QK_K * 20, "wrong q5_K block size/padding");
static_assert(sizeof(block_q5_Kx64) == sizeof(ggml_half) * 128 + K_SCALE_SIZE * 64 + QK_K * 40, "wrong q5_K block size/padding");
// 8 q6_K super-blocks interleaved column-wise (one per output column).
struct block_q6_Kx8 {
    ggml_half d[8];               // per-column super-block scales
    int8_t scales[QK_K / 16 * 8]; // per-column sub-block scales, one per 16 weights
    uint8_t ql[QK_K / 2 * 8]; // low bits of 6-bit quants (groups of 2)
    uint8_t qh[QK_K / 4 * 8]; // high bits of 6-bit quants (groups of 4)
};

static_assert(sizeof(block_q6_Kx8) == sizeof(ggml_half) * 8 + QK_K / 16 * 8 + 3 * QK_K / 4 * 8,
              "wrong q6_K block size/padding");
struct block_q8_Kx4 {
float d[4]; // delta
int8_t qs[QK_K * 4]; // quants
@ -125,17 +135,23 @@ static_assert(sizeof(block_iq4_nlx16) == 16 * sizeof(ggml_half) + QK4_NL * 8, "w
static_assert(sizeof(block_iq4_nlx32) == 32 * sizeof(ggml_half) + QK4_NL * 16, "wrong iq4_nlx32 block size/padding");
static_assert(sizeof(block_iq4_nlx64) == 64 * sizeof(ggml_half) + QK4_NL * 32, "wrong iq4_nlx64 block size/padding");
// Row-interleaved MXFP4 blocks: `N` source block_mxfp4 blocks packed together —
// per-row scales first, then quant nibbles interleaved byte-by-byte across rows.
// Replaces the old fixed block_mxfp4x4/block_mxfp4x8 structs, whose leftover
// (and unterminated) definitions conflicted with the aliases below.
template<int N> struct block_mxfp4x {
    ggml_half e[N];                 // deltas for `N` mxfp4 blocks
    uint8_t   qs[QK_MXFP4 * N / 2]; // nibbles / quants for N mxfp4 blocks
};

using block_mxfp4x4  = block_mxfp4x<4>;
using block_mxfp4x8  = block_mxfp4x<8>;
using block_mxfp4x16 = block_mxfp4x<16>;
using block_mxfp4x32 = block_mxfp4x<32>;
using block_mxfp4x64 = block_mxfp4x<64>;

static_assert(sizeof(block_mxfp4x4)  ==  4 * sizeof(ggml_half) + QK_MXFP4 * 2,  "wrong mxfp4x4 block size/padding");
static_assert(sizeof(block_mxfp4x8)  ==  8 * sizeof(ggml_half) + QK_MXFP4 * 4,  "wrong mxfp4x8 block size/padding");
static_assert(sizeof(block_mxfp4x16) == 16 * sizeof(ggml_half) + QK_MXFP4 * 8,  "wrong mxfp4x16 block size/padding");
static_assert(sizeof(block_mxfp4x32) == 32 * sizeof(ggml_half) + QK_MXFP4 * 16, "wrong mxfp4x32 block size/padding");
static_assert(sizeof(block_mxfp4x64) == 64 * sizeof(ggml_half) + QK_MXFP4 * 32, "wrong mxfp4x64 block size/padding");
#if defined(__cplusplus)
extern "C" {
@ -179,6 +195,7 @@ void ggml_gemm_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
void ggml_gemm_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
#if defined __riscv_zvfh
void ggml_quantize_mat_q8_0_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_K_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_gemv_q4_0_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@ -203,6 +220,10 @@ void ggml_gemv_iq4_nl_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_mxfp4_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_mxfp4_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_mxfp4_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_mxfp4_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@ -227,6 +248,10 @@ void ggml_gemm_iq4_nl_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_mxfp4_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_mxfp4_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_mxfp4_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_mxfp4_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
#endif
// Native implementations
@ -293,6 +318,10 @@ void ggml_gemv_iq4_nl_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_mxfp4_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_mxfp4_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_mxfp4_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_mxfp4_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@ -317,6 +346,10 @@ void ggml_gemm_iq4_nl_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_mxfp4_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_mxfp4_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_mxfp4_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_mxfp4_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
#endif
#if defined(__cplusplus)