From 495c363267b9113a5e3f753340279f6008e39f81 Mon Sep 17 00:00:00 2001 From: iacopPBK Date: Mon, 30 Mar 2026 00:48:19 +0200 Subject: [PATCH] ds_read_b128 for q4_0 and q4_1 mmq kernels Current for loop generates ds_read_b32 instructions with hip compiler, the new solution generates ds_read_b128 instructions for the same operation, saving some LDS bandwidth. Tested on MI50 and RX6800XT, its faster on both. --- ggml/src/ggml-cuda/mmq.cuh | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 255e59f6fc..5d1813e289 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -384,15 +384,25 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( int u[2*VDR_Q4_0_Q8_1_MMQ]; +#if defined(GGML_USE_HIP) + const int4 vec0 = *((const int4 *) &y_qs[j * MMQ_TILE_Y_K + kyqs]); + const int4 vec1 = *((const int4 *) &y_qs[j * MMQ_TILE_Y_K + kyqs + QI4_0]); + + u[0] = vec0.x; u[2] = vec0.y; u[4] = vec0.z; u[6] = vec0.w; + u[1] = vec1.x; u[3] = vec1.y; u[5] = vec1.z; u[7] = vec1.w; +#else + #pragma unroll for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) { u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l]; u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)]; } +#endif sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u, x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } } } @@ -487,12 +497,20 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( int u[2*VDR_Q4_1_Q8_1_MMQ]; +#if defined(GGML_USE_HIP) + const int4 vec0 = *((const int4 *) &y_qs[j * MMQ_TILE_Y_K + kyqs]); + const int4 vec1 = *((const int4 *) &y_qs[j * MMQ_TILE_Y_K + kyqs + QI4_0]); + + u[0] = vec0.x; u[2] = vec0.y; u[4] = vec0.z; u[6] = vec0.w; + u[1] = vec1.x; u[3] = vec1.y; u[5] = vec1.z; u[7] = vec1.w; +#else + #pragma unroll for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) { u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l]; u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)]; } - +#endif sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_1_q8_1_impl (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_1], u, x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);