Explicit for loop in mmq, renamed vec into tmp

This commit is contained in:
iacopPBK 2026-03-30 17:06:12 +02:00
parent cc9ea913bc
commit 62c2f8f7c0
1 changed files with 18 additions and 13 deletions

View File

@ -382,15 +382,17 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
int u[2*VDR_Q4_0_Q8_1_MMQ];
constexpr int max_cpy = ggml_cuda_get_max_cpy_bytes();
int tmp0[4], tmp1[4];
constexpr int max_cpy = ggml_cuda_get_max_cpy_bytes();
#pragma unroll
for (int l0 = 0; l0 < 4*sizeof(int)/max_cpy; ++l0) {
ggml_cuda_memcpy_1<max_cpy>(tmp0 + l0*max_cpy, (&y_qs[j*MMQ_TILE_Y_K + kyqs]) + l0*max_cpy);
ggml_cuda_memcpy_1<max_cpy>(tmp1 + l0*max_cpy, (&y_qs[j*MMQ_TILE_Y_K + kyqs + QI4_0]) + l0*max_cpy);
}
int4 vec0, vec1;
ggml_cuda_memcpy_1<4*sizeof(int), max_cpy>(&vec0, &y_qs[j*MMQ_TILE_Y_K + kyqs]);
ggml_cuda_memcpy_1<4*sizeof(int), max_cpy>(&vec1, &y_qs[j*MMQ_TILE_Y_K + kyqs + QI4_0]);
u[0]=vec0.x; u[2]=vec0.y; u[4]=vec0.z; u[6]=vec0.w;
u[1]=vec1.x; u[3]=vec1.y; u[5]=vec1.z; u[7]=vec1.w;
u[0]=tmp0[0]; u[2]=tmp0[1]; u[4]=tmp0[2]; u[6]=tmp0[3];
u[1]=tmp1[0]; u[3]=tmp1[1]; u[5]=tmp1[2]; u[7]=tmp1[3];
sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
(&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u,
@ -490,13 +492,16 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
int u[2*VDR_Q4_1_Q8_1_MMQ];
constexpr int max_cpy = ggml_cuda_get_max_cpy_bytes();
int4 vec0, vec1;
ggml_cuda_memcpy_1<4*sizeof(int), max_cpy>(&vec0, &y_qs[j*MMQ_TILE_Y_K + kyqs]);
ggml_cuda_memcpy_1<4*sizeof(int), max_cpy>(&vec1, &y_qs[j*MMQ_TILE_Y_K + kyqs + QI4_1]);
int tmp0[4], tmp1[4];
u[0]=vec0.x; u[2]=vec0.y; u[4]=vec0.z; u[6]=vec0.w;
u[1]=vec1.x; u[3]=vec1.y; u[5]=vec1.z; u[7]=vec1.w;
#pragma unroll
for (int l0 = 0; l0 < 4*sizeof(int)/max_cpy; ++l0) {
ggml_cuda_memcpy_1<max_cpy>(tmp0 + l0*max_cpy, (&y_qs[j*MMQ_TILE_Y_K + kyqs]) + l0*max_cpy);
ggml_cuda_memcpy_1<max_cpy>(tmp1 + l0*max_cpy, (&y_qs[j*MMQ_TILE_Y_K + kyqs + QI4_0]) + l0*max_cpy);
}
u[0]=tmp0[0]; u[2]=tmp0[1]; u[4]=tmp0[2]; u[6]=tmp0[3];
u[1]=tmp1[0]; u[3]=tmp1[1]; u[5]=tmp1[2]; u[7]=tmp1[3];
sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
(&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_1], u,