From 0ca43582e853014a79c71843b86144fe245353a3 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 8 Oct 2025 13:52:56 -0400 Subject: [PATCH] reorder register tile loop --- ggml/src/ggml-cuda/conv2d-implicit.cu | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index cae35280c0..f2af27a7fb 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -183,12 +183,20 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, input_frag[(subcrs + 1) & 1][i + 4] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i + 32]; } +// #pragma unroll +// for (int i = 0; i < 8; ++i){ +// auto weight_frag_i = ggml_cuda_cast(weight_frag[subcrs % 2][i]); +// #pragma unroll +// for (int j = 0; j < 8; ++j){ +// output_frag[i][j] += weight_frag_i * input_frag[subcrs % 2][j]; +// } +// } #pragma unroll - for (int i = 0; i < 8; ++i){ - auto weight_frag_i = ggml_cuda_cast(weight_frag[subcrs % 2][i]); + for (int j = 0; j < 8; ++j){ + // auto weight_frag_i = ggml_cuda_cast(weight_frag[subcrs % 2][i]); #pragma unroll - for (int j = 0; j < 8; ++j){ - output_frag[i][j] += weight_frag_i * input_frag[subcrs % 2][j]; + for (int i = 0; i < 8; ++i){ + output_frag[j][i] += ggml_cuda_cast(weight_frag[subcrs % 2][i]) * input_frag[subcrs % 2][j]; } } } @@ -215,7 +223,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, for (int i = 0; i < 8; ++i){ #pragma unroll for (int j = 0; j < 8; ++j){ - output_frag[i][j] += ggml_cuda_cast(weight_frag[1][i]) * input_frag[1][j]; + output_frag[i][j] += ggml_cuda_cast(weight_frag[1][j]) * input_frag[1][i]; } } } @@ -240,15 +248,15 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, #pragma unroll for (int subj = 0; subj < 4; ++subj){ // output sts - smemoutput[output_sts_addr + subi * 8 * 4 + subj] = output_frag[i * 4 + subi][j * 4 + subj]; + smemoutput[output_sts_addr + subj * 8 * 4 + subi] = output_frag[i * 4 + subi][j * 4 + subj]; } } __syncthreads(); #pragma unroll for (int subk = 0; subk < 16; ++subk){ - int outOffset = z * param.k * param.Oh * param.Ow + (m_idx + i * 16 + subk) * param.Oh * param.Ow + n_idx + j * 32; - if ((m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow) + int outOffset = z * param.k * param.Oh * param.Ow + (m_idx + j * 16 + subk) * param.Oh * param.Ow + n_idx + i * 32; + if ((m_idx + j * 16 + subk) < param.k && (n_idx + i * 32) < param.Oh * param.Ow) output[outOffset] = smemoutput[output_lds_addr + subk * 32]; } }