From 495c363267b9113a5e3f753340279f6008e39f81 Mon Sep 17 00:00:00 2001 From: iacopPBK Date: Mon, 30 Mar 2026 00:48:19 +0200 Subject: [PATCH 1/8] ds_read_b128 for q4_0 and q4_1 mmq kernels Current for loop generates ds_read_b32 instructions with hip compiler, the new solution generates ds_read_b128 instructions for the same operation, saving some LDS bandwidth. Tested on MI50 and RX6800XT, its faster on both. --- ggml/src/ggml-cuda/mmq.cuh | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 255e59f6fc..5d1813e289 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -384,15 +384,25 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( int u[2*VDR_Q4_0_Q8_1_MMQ]; +#if defined(GGML_USE_HIP) + const int4 vec0 = *((const int4 *) &y_qs[j * MMQ_TILE_Y_K + kyqs]); + const int4 vec1 = *((const int4 *) &y_qs[j * MMQ_TILE_Y_K + kyqs + QI4_0]); + + u[0] = vec0.x; u[2] = vec0.y; u[4] = vec0.z; u[6] = vec0.w; + u[1] = vec1.x; u[3] = vec1.y; u[5] = vec1.z; u[7] = vec1.w; +#else + #pragma unroll for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) { u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l]; u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)]; } +#endif sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u, x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } } } @@ -487,12 +497,20 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( int u[2*VDR_Q4_1_Q8_1_MMQ]; +#if defined(GGML_USE_HIP) + const int4 vec0 = *((const int4 *) &y_qs[j * MMQ_TILE_Y_K + kyqs]); + const int4 vec1 = *((const int4 *) &y_qs[j * MMQ_TILE_Y_K + kyqs + QI4_0]); + + u[0] = vec0.x; u[2] = vec0.y; u[4] = vec0.z; u[6] = vec0.w; + u[1] = vec1.x; u[3] = vec1.y; u[5] = vec1.z; u[7] = vec1.w; +#else + #pragma unroll for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) { u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l]; u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)]; } - +#endif sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_1_q8_1_impl (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_1], u, x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); From cc9ea913bc3eb3f9ed3337bf646d9a141f3b38df Mon Sep 17 00:00:00 2001 From: iacopPBK Date: Mon, 30 Mar 2026 14:24:47 +0200 Subject: [PATCH 2/8] Vectorized lds load update: used ggml_cuda_get_max_cpy_bytes and ggml_cuda_memcpy_1 functions for generic implementation --- ggml/src/ggml-cuda/mmq.cuh | 40 +++++++++++++------------------------- 1 file changed, 14 insertions(+), 26 deletions(-) diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 5d1813e289..0563e2e207 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -379,25 +379,18 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; - const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2); int u[2*VDR_Q4_0_Q8_1_MMQ]; -#if defined(GGML_USE_HIP) - const int4 vec0 = *((const int4 *) &y_qs[j * MMQ_TILE_Y_K + kyqs]); - const int4 vec1 = *((const int4 *) &y_qs[j * MMQ_TILE_Y_K + kyqs + QI4_0]); + constexpr int max_cpy = ggml_cuda_get_max_cpy_bytes(); - u[0] = vec0.x; u[2] = vec0.y; u[4] = vec0.z; u[6] = vec0.w; - u[1] = vec1.x; u[3] = vec1.y; u[5] = vec1.z; u[7] = vec1.w; -#else + int4 vec0, vec1; + ggml_cuda_memcpy_1<4*sizeof(int), max_cpy>(&vec0, &y_qs[j*MMQ_TILE_Y_K + kyqs]); + ggml_cuda_memcpy_1<4*sizeof(int), max_cpy>(&vec1, &y_qs[j*MMQ_TILE_Y_K + kyqs + QI4_0]); -#pragma unroll - for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l]; - u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)]; - } -#endif + u[0]=vec0.x; u[2]=vec0.y; u[4]=vec0.z; u[6]=vec0.w; + u[1]=vec1.x; u[3]=vec1.y; u[5]=vec1.z; u[7]=vec1.w; sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u, @@ -492,25 +485,19 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; - const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2); int u[2*VDR_Q4_1_Q8_1_MMQ]; -#if defined(GGML_USE_HIP) - const int4 vec0 = *((const int4 *) &y_qs[j * MMQ_TILE_Y_K + kyqs]); - const int4 vec1 = *((const int4 *) &y_qs[j * MMQ_TILE_Y_K + kyqs + QI4_0]); + constexpr int max_cpy = ggml_cuda_get_max_cpy_bytes(); + + int4 vec0, vec1; + ggml_cuda_memcpy_1<4*sizeof(int), max_cpy>(&vec0, &y_qs[j*MMQ_TILE_Y_K + kyqs]); + ggml_cuda_memcpy_1<4*sizeof(int), max_cpy>(&vec1, &y_qs[j*MMQ_TILE_Y_K + kyqs + QI4_1]); - u[0] = vec0.x; u[2] = vec0.y; u[4] = vec0.z; u[6] = vec0.w; - u[1] = vec1.x; u[3] = vec1.y; u[5] = vec1.z; u[7] = vec1.w; -#else + u[0]=vec0.x; u[2]=vec0.y; u[4]=vec0.z; u[6]=vec0.w; + u[1]=vec1.x; u[3]=vec1.y; u[5]=vec1.z; u[7]=vec1.w; -#pragma unroll - for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l]; - u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)]; - } -#endif sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_1_q8_1_impl (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_1], u, x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); @@ -4113,3 +4100,4 @@ void ggml_cuda_op_mul_mat_q( const int64_t src1_padded_row_size, cudaStream_t stream); bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts); + From 62c2f8f7c00a55dcce89031d92a14d71ccd09114 Mon Sep 17 00:00:00 2001 From: iacopPBK Date: Mon, 30 Mar 2026 17:06:12 +0200 Subject: [PATCH 3/8] Explicit for loop in mmq, renamed vec into tmp --- ggml/src/ggml-cuda/mmq.cuh | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 0563e2e207..3a0ee99710 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -382,15 +382,17 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2); int u[2*VDR_Q4_0_Q8_1_MMQ]; + constexpr int max_cpy = ggml_cuda_get_max_cpy_bytes(); + int tmp0[4], tmp1[4]; - constexpr int max_cpy = ggml_cuda_get_max_cpy_bytes(); +#pragma unroll + for (int l0 = 0; l0 < 4*sizeof(int)/max_cpy; ++l0) { + ggml_cuda_memcpy_1(tmp0 + l0*max_cpy, (&y_qs[j*MMQ_TILE_Y_K + kyqs]) + l0*max_cpy); + ggml_cuda_memcpy_1(tmp1 + l0*max_cpy, (&y_qs[j*MMQ_TILE_Y_K + kyqs + QI4_0]) + l0*max_cpy); + } - int4 vec0, vec1; - ggml_cuda_memcpy_1<4*sizeof(int), max_cpy>(&vec0, &y_qs[j*MMQ_TILE_Y_K + kyqs]); - ggml_cuda_memcpy_1<4*sizeof(int), max_cpy>(&vec1, &y_qs[j*MMQ_TILE_Y_K + kyqs + QI4_0]); - - u[0]=vec0.x; u[2]=vec0.y; u[4]=vec0.z; u[6]=vec0.w; - u[1]=vec1.x; u[3]=vec1.y; u[5]=vec1.z; u[7]=vec1.w; + u[0]=tmp0[0]; u[2]=tmp0[1]; u[4]=tmp0[2]; u[6]=tmp0[3]; + u[1]=tmp1[0]; u[3]=tmp1[1]; u[5]=tmp1[2]; u[7]=tmp1[3]; sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u, @@ -490,13 +492,16 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( int u[2*VDR_Q4_1_Q8_1_MMQ]; constexpr int max_cpy = ggml_cuda_get_max_cpy_bytes(); - - int4 vec0, vec1; - ggml_cuda_memcpy_1<4*sizeof(int), max_cpy>(&vec0, &y_qs[j*MMQ_TILE_Y_K + kyqs]); - ggml_cuda_memcpy_1<4*sizeof(int), max_cpy>(&vec1, &y_qs[j*MMQ_TILE_Y_K + kyqs + QI4_1]); + int tmp0[4], tmp1[4]; - u[0]=vec0.x; u[2]=vec0.y; u[4]=vec0.z; u[6]=vec0.w; - u[1]=vec1.x; u[3]=vec1.y; u[5]=vec1.z; u[7]=vec1.w; +#pragma unroll + for (int l0 = 0; l0 < 4*sizeof(int)/max_cpy; ++l0) { + ggml_cuda_memcpy_1(tmp0 + l0*max_cpy, (&y_qs[j*MMQ_TILE_Y_K + kyqs]) + l0*max_cpy); + ggml_cuda_memcpy_1(tmp1 + l0*max_cpy, (&y_qs[j*MMQ_TILE_Y_K + kyqs + QI4_0]) + l0*max_cpy); + } + + u[0]=tmp0[0]; u[2]=tmp0[1]; u[4]=tmp0[2]; u[6]=tmp0[3]; + u[1]=tmp1[0]; u[3]=tmp1[1]; u[5]=tmp1[2]; u[7]=tmp1[3]; sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_1_q8_1_impl (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_1], u, From 0bcddd2164ac87a4608270b38bc8ec6447e4265a Mon Sep 17 00:00:00 2001 From: iacopPBK Date: Mon, 30 Mar 2026 19:27:41 +0200 Subject: [PATCH 4/8] Fixed max_cpy usage in the loading loop --- ggml/src/ggml-cuda/mmq.cuh | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 3a0ee99710..a2098248ee 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -382,13 +382,16 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2); int u[2*VDR_Q4_0_Q8_1_MMQ]; + constexpr int max_cpy = ggml_cuda_get_max_cpy_bytes(); + constexpr int mcpy_int = max_cpy / sizeof(int); + int tmp0[4], tmp1[4]; -#pragma unroll - for (int l0 = 0; l0 < 4*sizeof(int)/max_cpy; ++l0) { - ggml_cuda_memcpy_1(tmp0 + l0*max_cpy, (&y_qs[j*MMQ_TILE_Y_K + kyqs]) + l0*max_cpy); - ggml_cuda_memcpy_1(tmp1 + l0*max_cpy, (&y_qs[j*MMQ_TILE_Y_K + kyqs + QI4_0]) + l0*max_cpy); + #pragma unroll + for (int l0 = 0; l0 < 4 / mcpy_int; ++l0) { + ggml_cuda_memcpy_1(tmp0 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + l0 * mcpy_int] ); + ggml_cuda_memcpy_1(tmp1 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + QI4_0 + l0 * mcpy_int]); } u[0]=tmp0[0]; u[2]=tmp0[1]; u[4]=tmp0[2]; u[6]=tmp0[3]; @@ -489,15 +492,17 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( const int i = i0 + threadIdx.x; const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2); - int u[2*VDR_Q4_1_Q8_1_MMQ]; + int u[2*VDR_Q4_0_Q8_1_MMQ]; constexpr int max_cpy = ggml_cuda_get_max_cpy_bytes(); + constexpr int mcpy_int = max_cpy / sizeof(int); + int tmp0[4], tmp1[4]; -#pragma unroll - for (int l0 = 0; l0 < 4*sizeof(int)/max_cpy; ++l0) { - ggml_cuda_memcpy_1(tmp0 + l0*max_cpy, (&y_qs[j*MMQ_TILE_Y_K + kyqs]) + l0*max_cpy); - ggml_cuda_memcpy_1(tmp1 + l0*max_cpy, (&y_qs[j*MMQ_TILE_Y_K + kyqs + QI4_0]) + l0*max_cpy); + #pragma unroll + for (int l0 = 0; l0 < 4 / mcpy_int; ++l0) { + ggml_cuda_memcpy_1(tmp0 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + l0 * mcpy_int] ); + ggml_cuda_memcpy_1(tmp1 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + QI4_1 + l0 * mcpy_int]); } u[0]=tmp0[0]; u[2]=tmp0[1]; u[4]=tmp0[2]; u[6]=tmp0[3]; From 5d7df5df44fd34fbc4c2d505195f1ea7f20c4c56 Mon Sep 17 00:00:00 2001 From: iacopPBK Date: Mon, 30 Mar 2026 19:32:12 +0200 Subject: [PATCH 5/8] Fixed typo in q4_1 kernel --- ggml/src/ggml-cuda/mmq.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index a2098248ee..9618579424 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -492,7 +492,7 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( const int i = i0 + threadIdx.x; const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2); - int u[2*VDR_Q4_0_Q8_1_MMQ]; + int u[2*VDR_Q4_1_Q8_1_MMQ]; constexpr int max_cpy = ggml_cuda_get_max_cpy_bytes(); constexpr int mcpy_int = max_cpy / sizeof(int); From d3065542f0a73309cfbb0ecfb3071d84f8826145 Mon Sep 17 00:00:00 2001 From: iacopPBK Date: Wed, 1 Apr 2026 13:23:42 +0200 Subject: [PATCH 6/8] Update ggml/src/ggml-cuda/mmq.cuh MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/mmq.cuh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 9618579424..209555ce39 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -384,7 +384,8 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( int u[2*VDR_Q4_0_Q8_1_MMQ]; constexpr int max_cpy = ggml_cuda_get_max_cpy_bytes(); - constexpr int mcpy_int = max_cpy / sizeof(int); + constexpr int mcpy_int = max_cpy / sizeof(int); + static_assert(VDR_Q4_0_Q8_1_MMQ == 4, "bad VDR_Q4_0_Q8_1_MMQ"); int tmp0[4], tmp1[4]; From 777f5943a41f31c45823fb6572aa784f8ed1fa06 Mon Sep 17 00:00:00 2001 From: iacopPBK Date: Wed, 1 Apr 2026 13:24:02 +0200 Subject: [PATCH 7/8] Update ggml/src/ggml-cuda/mmq.cuh MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/mmq.cuh | 1 - 1 file changed, 1 deletion(-) diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 209555ce39..5424ba79b5 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -401,7 +401,6 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u, x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); - } } } From fbc4cfcdde08bd87086e7d5d35ddec412ca49f6a Mon Sep 17 00:00:00 2001 From: iacopPBK Date: Wed, 1 Apr 2026 13:24:19 +0200 Subject: [PATCH 8/8] Update ggml/src/ggml-cuda/mmq.cuh MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/mmq.cuh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 5424ba79b5..1719613990 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -495,7 +495,9 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( int u[2*VDR_Q4_1_Q8_1_MMQ]; constexpr int max_cpy = ggml_cuda_get_max_cpy_bytes(); - constexpr int mcpy_int = max_cpy / sizeof(int); + constexpr int mcpy_int = max_cpy / sizeof(int); + static_assert(VDR_Q4_0_Q8_1_MMQ == 4, "bad VDR_Q4_0_Q8_1_MMQ"); + int tmp0[4], tmp1[4];