From 353893046633e2d6e803bb1cbef1632c3d1fa429 Mon Sep 17 00:00:00 2001 From: Aaron Teo Date: Tue, 23 Sep 2025 14:11:43 +0800 Subject: [PATCH] ggml-cpu: rework mxfp4 Signed-off-by: Aaron Teo --- ggml/src/ggml-cpu/arch/s390/quants.c | 42 ++++------------------------ 1 file changed, 5 insertions(+), 37 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/s390/quants.c b/ggml/src/ggml-cpu/arch/s390/quants.c index daa2142c60..7930bb42ec 100644 --- a/ggml/src/ggml-cpu/arch/s390/quants.c +++ b/ggml/src/ggml-cpu/arch/s390/quants.c @@ -280,46 +280,13 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo #if defined(__VXE__) || defined(__VXE2__) const int8x16_t v_k = vec_xl(0, kvalues_mxfp4); - const uint8x16_t v_m = vec_splats((uint8_t)0x0F); - - for (; ib + 1 < nb; ib += 2) { - const block_mxfp4 * GGML_RESTRICT x0 = &x[ib + 0]; - const block_mxfp4 * GGML_RESTRICT x1 = &x[ib + 1]; - const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0]; - const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1]; - - const uint8x16_t v_x0 = vec_xl(0, x0->qs); - const uint8x16_t v_x1 = vec_xl(0, x1->qs); - - int8x16_t v_x0l = (int8x16_t)vec_and(v_x0, v_m); - int8x16_t v_x0h = (int8x16_t)vec_sr(v_x0, 4); - int8x16_t v_x1l = (int8x16_t)vec_and(v_x1, v_m); - int8x16_t v_x1h = (int8x16_t)vec_sr(v_x1, 4); - - v_x0l = vec_perm(v_k, v_k, (uchar8x16_t)v_x0l); - v_x0h = vec_perm(v_k, v_k, (uchar8x16_t)v_x0h); - v_x1l = vec_perm(v_k, v_k, (uchar8x16_t)v_x1l); - v_x1h = vec_perm(v_k, v_k, (uchar8x16_t)v_x1h); - - const int8x16_t v_y0l = vec_xl(0, y0->qs); - const int8x16_t v_y0h = vec_xl(QK8_0/2, y0->qs); - const int8x16_t v_y1l = vec_xl(0, y1->qs); - const int8x16_t v_y1h = vec_xl(QK8_0/2, y1->qs); - - const int32x4_t v_xy0 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x0l, v_y0l), v_x0h, v_y0h); - const int32x4_t v_xy1 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x1l, v_y1l), v_x1h, v_y1h); - - sumf += - GGML_E8M0_TO_FP32(x0->e) * GGML_CPU_FP16_TO_FP32(y0->d) * vec_hsum_i32x4(v_xy0) + - GGML_E8M0_TO_FP32(x1->e) * GGML_CPU_FP16_TO_FP32(y1->d) * vec_hsum_i32x4(v_xy1); - } + const uint8x16_t v_m = vec_splats((const uint8_t)0x0F); for (; ib < nb; ++ib) { - const block_mxfp4 * GGML_RESTRICT x0 = &x[ib + 0]; - const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0]; + const block_mxfp4 * GGML_RESTRICT x0 = &x[ib]; + const block_q8_0 * GGML_RESTRICT y0 = &y[ib]; const uint8x16_t v_x = vec_xl(0, x0->qs); - int8x16_t v_xl = (int8x16_t)vec_and(v_x, v_m); int8x16_t v_xh = (int8x16_t)vec_sr(v_x, 4); @@ -331,7 +298,8 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo const int32x4_t v_xy = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh); - sumf += GGML_E8M0_TO_FP32(x0->e) * GGML_CPU_FP16_TO_FP32(y0->d) * vec_hsum_i32x4(v_xy); + const float scale = GGML_E8M0_TO_FP32(x0->e) * GGML_CPU_FP16_TO_FP32(y0->d); + sumf += scale * vec_hsum_i32x4(v_xy); } *s = sumf;