From 414bb8d9ed02050c6f37b4cfe57ce70a78434bfe Mon Sep 17 00:00:00 2001 From: bssrdf Date: Fri, 7 Nov 2025 23:20:46 -0500 Subject: [PATCH] further reduce index swizzling computation cycles --- ggml/src/ggml-cuda/conv2d-implicit.cu | 11 +++++++---- tests/test-conv2d.cpp | 5 +++-- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 2fd244389d..0ec9dca1bd 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -691,15 +691,18 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, #pragma unroll for (int subk = 0; subk < WN / 4; ++subk){ const uint row = m_i_wn + subk*2; + uint idx = output_lds_addr + subk*2; + idx = idx ^ ((idx & 0b110000000000) >> 9); + idx = idx ^ ((idx & 0b1110000000) >> 4); #pragma unroll for (int j = 0; j < 4; ++j){ const uint gemm_i = n_idx + j*32; const int n = fastdiv(gemm_i, param.OHOW_fastdiv); const int col = fastmodulo(gemm_i, param.OHOW_fastdiv); - uint idx = output_lds_addr + subk*2 + j*32*BN/2; - idx = idx ^ ((idx & 0b110000000000) >> 9); - idx = idx ^ ((idx & 0b1110000000) >> 4); - uint32_t dst_ptr = *(reinterpret_cast(&smemoutput[idx])); + // uint idx = output_lds_addr + subk*2 + j*32*BN/2; + // idx = idx ^ ((idx & 0b110000000000) >> 9); + // idx = idx ^ ((idx & 0b1110000000) >> 4); + uint32_t dst_ptr = *(reinterpret_cast(&smemoutput[idx+j*32*BN/2])); half (&res_)[2] = reinterpret_cast(dst_ptr); if (n < param.n && row < param.k && col < PQ) { if constexpr (ksplit > 0) { diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index 720ddbf269..57edc02474 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -325,8 +325,8 @@ int main(void) std::make_tuple(512,256,416,608,3,3), std::make_tuple(256,128,832,1216,3,3), std::make_tuple(256,256,832,1216,3,3), - // std::make_tuple(320,256,1024,1920) std::make_tuple(32,64,58,58,3,3) + // std::make_tuple(320,256,1024,1920) }; std::vector> configs_sdxl_512 = { //512x512 @@ -648,7 +648,8 @@ int main(void) int k = 0; - for (auto c : configs_sdxl_1024){ + for (auto c : configs_sdxl_512){ + // for (auto c : configs){ test_model model; load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c), std::get<3>(c), std::get<4>(c), std::get<5>(c), true);