From 949eca4cba3e08afcbfbf48c8d288cbf188b9bb1 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Fri, 7 Nov 2025 19:20:12 -0500 Subject: [PATCH] swizzling working, may still have room to optimize --- ggml/src/ggml-cuda/conv2d-implicit.cu | 14 ++++---------- tests/test-conv2d.cpp | 20 ++++++++++---------- 2 files changed, 14 insertions(+), 20 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index b02224fc06..3a84935582 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -677,11 +677,14 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, uint32_t (&reg_)[2] = reinterpret_cast<uint32_t(&)[2]>(acc_register_[mma_m][mma_n]); uint idx = output_sts_addr + mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N; + uint idx8 = idx + 8 * BN / 2; idx = idx ^ ((idx & 0b110000000000) >> 9); idx = idx ^ ((idx & 0b1110000000) >> 4); uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[idx]); dst_ptr[0] = reg_[0]; - dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[idx + 8 * BN / 2]); + idx8 = idx8 ^ ((idx8 & 0b110000000000) >> 9); + idx8 = idx8 ^ ((idx8 & 0b1110000000) >> 4); + dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[idx8]); dst_ptr[0] = reg_[1]; } } @@ -698,7 +701,6 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, uint idx = output_lds_addr + subk*2 + j*32*BN/2; idx = idx ^ ((idx & 0b110000000000) >> 9); idx = idx ^ ((idx & 0b1110000000) >> 4); - // uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[idx]); uint32_t dst_ptr = *(reinterpret_cast<uint32_t*>(&smemoutput[idx])); half (&res_)[2] = reinterpret_cast<half(&)[2]>(dst_ptr); if (n < param.n && row < param.k && col < PQ) { @@ -706,13 +708,9 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, const uint outOffset = z * NKPQ + n * KPQ + row * PQ + col; - // output[outOffset] = smemoutput[idx]; - // output[outOffset] = reinterpret_cast(dst_ptr)[0]; output[outOffset] = res_[0]; } else { 
const uint outOffset = n * KPQ + row * PQ + col; - // output[outOffset] = smemoutput[idx]; - // output[outOffset] = reinterpret_cast(dst_ptr)[0]; output[outOffset] = res_[0]; } } @@ -721,13 +719,9 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, const uint outOffset = z * NKPQ + n * KPQ + (row+1) * PQ + col; - // output[outOffset] = smemoutput[idx]; - // output[outOffset] = reinterpret_cast(dst_ptr)[1]; output[outOffset] = res_[1]; } else { const uint outOffset = n * KPQ + (row+1) * PQ + col; - // output[outOffset] = smemoutput[idx]; - // output[outOffset] = reinterpret_cast(dst_ptr)[1]; output[outOffset] = res_[1]; } } diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index 0b1b5c476f..75778b6e30 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -705,16 +705,16 @@ int main(void) // for(int i = 0; i < ggml_nelements(wino_res); i++) { - for(int i = 0; i < 26*38; i++) { - // for(int i = 0; i < conv2d_data.size(); i++) { - float diff = fabs(im2col_data[i] - conv2d_data[i]); - // if(diff > 0.5) { - printf("(%7.3f, %7.3f, %.2f, %d) \n", - im2col_data[i], conv2d_data[i], - diff, i); - // break; - // } - } + // for(int i = 0; i < 26*38; i++) { + // // for(int i = 0; i < conv2d_data.size(); i++) { + // float diff = fabs(im2col_data[i] - conv2d_data[i]); + // // if(diff > 0.5) { + // printf("(%7.3f, %7.3f, %.2f, %d) \n", + // im2col_data[i], conv2d_data[i], + // diff, i); + // // break; + // // } + // } ggml_free(model.ctx); ggml_backend_buffer_free(model.buffer);