reduce bank conflicts in filter transpose

bssrdf 2025-11-09 00:51:51 -05:00
parent 8e0e944b70
commit 5ed2c1b787
2 changed files with 43 additions and 19 deletions


@@ -28,6 +28,19 @@ static __global__ void reduce_f32(const src_T * __restrict__ x, dst_T * __restri
}
}
constexpr uint32_t filter_swizzle_mask(uint32_t n, uint32_t m) {
if (n <= 1) return 1;
n--;
n |= n >> 1;
n |= n >> 2;
n |= n >> 4;
n |= n >> 8;
n |= n >> 16;
int count = 0;
while ((m >>= 1) != 0)
++count;
return n << count;
}
template <typename src_T, typename dst_T>
static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, const int ne00, const int ne01){
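
For context on the new helper: filter_swizzle_mask(n, m) rounds n up to the next power of two and subtracts one, giving a run of ceil(log2(n)) set bits, then shifts that run left by floor(log2(m)) so the bits line up with the row field of a row*m + col shared-memory offset. A minimal compile-time sketch of the values it produces, assuming CUDA_NCHW_2_NHWC_BLOCK_C is 64 (an illustrative value, not taken from this commit):

#include <cstdint>

// Verbatim copy of the helper added above, so the checks below are self-contained.
constexpr uint32_t filter_swizzle_mask(uint32_t n, uint32_t m) {
    if (n <= 1) return 1;
    n--;
    n |= n >> 1;
    n |= n >> 2;
    n |= n >> 4;
    n |= n >> 8;
    n |= n >> 16;
    int count = 0;
    while ((m >>= 1) != 0)
        ++count;
    return n << count;
}

// With m = 64, bits 0..5 of row*64 + col hold the column and bits 6 and up hold the row,
// so the returned mask selects the low row bits of the offset.
static_assert(filter_swizzle_mask(25, 64) == (31u << 6), "rs = 25: five row bits masked");
static_assert(filter_swizzle_mask(16, 64) == (15u << 6), "rs = 16: four row bits masked");
static_assert(filter_swizzle_mask( 2, 64) == ( 1u << 6), "rs = 2: one row bit masked");
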
@@ -65,7 +78,7 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co
}
}
template <typename src_T, typename dst_T, const int rs, const unsigned int blk_c>
template <typename src_T, typename dst_T, const unsigned int mask, const int rs, const unsigned int blk_c>
static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, const int ne00, const int ne01){
const int64_t nmat = ne / (ne00 * ne01);
@@ -74,9 +87,6 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co
const unsigned int tx = threadIdx.x;
const unsigned int bx = blockIdx.x;
const unsigned int by = blockIdx.y;
// int y = blockIdx.y * CUDA_NCHW_2_NHWC_TILE_DIM + threadIdx.y;
// int tx = blockIdx.y * CUDA_NCHW_2_NHWC_TILE_DIM + threadIdx.x; // transpose block offset
// int ty = blockIdx.x * CUDA_NCHW_2_NHWC_TILE_DIM + threadIdx.y;
__shared__ src_T tile[rs*blk_c];
@@ -89,8 +99,10 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co
const unsigned int row = (j * blk_c + tx) % rs;
const unsigned int col = (j * blk_c + tx) / rs;
const unsigned int src_index = imat*n + bx * blk_c * rs + j * blk_c + tx;
unsigned int idx = row * blk_c + col;
idx = idx ^ ((idx & mask) >> 4);
if (src_index < ne) {
tile[row * blk_c + col] = src[src_index];
tile[idx] = src[src_index];
}
}
__syncthreads();
@@ -98,7 +110,9 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co
for (unsigned int j = 0; j < rs; j++){
const unsigned int dst_index = imat*n + j*ne00 + bx*blk_c + tx;
if(dst_index < ne){
dst[dst_index] = ggml_cuda_cast<dst_T>(tile[j*blk_c+tx]);
unsigned int idx = j*blk_c + tx;
idx = idx ^ ((idx & mask) >> 4);
dst[dst_index] = ggml_cuda_cast<dst_T>(tile[idx]);
}
}
}
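
Why the XOR swizzle helps, in rough terms: in the load phase above, consecutive threads write tile[row * blk_c + col] with row advancing by one per thread, i.e. a shared-memory stride of blk_c elements, which tends to funnel many lanes of a warp into the same bank. XORing a few row bits into lower bits of the offset (idx ^ ((idx & mask) >> 4)) permutes elements within the tile so that successive rows start in different banks, and because the store phase applies the same permutation, every element is read back from where it was written. The sketch below estimates the effect on the write pattern; it assumes blk_c = 64, rs = 25, 2-byte half elements, and 32 four-byte banks (illustrative values, not taken from the commit), and it deliberately ignores that two halves sharing one 4-byte word do not conflict:

#include <cstdint>
#include <cstdio>

// Same index transformation as in the kernel above.
static uint32_t swizzle(uint32_t idx, uint32_t mask) {
    return idx ^ ((idx & mask) >> 4);
}

int main() {
    const uint32_t rs = 25, blk_c = 64;       // assumed 5x5 filter, assumed block width
    const uint32_t mask = 31u << 6;           // filter_swizzle_mask(25, 64)
    const uint32_t elem_bytes = 2, banks = 32, bank_bytes = 4;
    for (int swizzled = 0; swizzled <= 1; ++swizzled) {
        uint32_t worst = 0;
        // first iteration (j = 0) of the load loop: thread tx stores tile[row*blk_c + col]
        for (uint32_t warp = 0; warp < blk_c; warp += 32) {
            uint32_t hits[32] = {0};
            for (uint32_t lane = 0; lane < 32; ++lane) {
                const uint32_t tx  = warp + lane;
                const uint32_t row = tx % rs;
                const uint32_t col = tx / rs;
                uint32_t idx = row * blk_c + col;
                if (swizzled) idx = swizzle(idx, mask);
                const uint32_t bank = (idx * elem_bytes / bank_bytes) % banks;
                if (++hits[bank] > worst) worst = hits[bank];
            }
        }
        printf("%s: worst-case requests to a single bank per warp = %u\n",
               swizzled ? "swizzled" : "naive   ", worst);
    }
    return 0;
}

Under these assumptions the naive layout sends an entire warp to one bank while the swizzled layout spreads the accesses across several; treat the numbers as a trend rather than a hardware measurement.
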
@@ -872,25 +886,35 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
(ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM,
1) ;
if (ne01 == 25) {
NCHW2NHWC<half, half, 25, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
constexpr unsigned int mask = filter_swizzle_mask(25, CUDA_NCHW_2_NHWC_BLOCK_C);
NCHW2NHWC<half, half, mask, 25, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
}else if (ne01 == 16) {
NCHW2NHWC<half, half, 16, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
constexpr unsigned int mask = filter_swizzle_mask(16, CUDA_NCHW_2_NHWC_BLOCK_C);
NCHW2NHWC<half, half, mask, 16, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
}else if (ne01 == 9) {
NCHW2NHWC<half, half, 9, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
constexpr unsigned int mask = filter_swizzle_mask(9, CUDA_NCHW_2_NHWC_BLOCK_C);
NCHW2NHWC<half, half, mask, 9, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
} else if (ne01 == 8) {
NCHW2NHWC<half, half, 8, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
constexpr unsigned int mask = filter_swizzle_mask(8, CUDA_NCHW_2_NHWC_BLOCK_C);
NCHW2NHWC<half, half, mask, 8, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
} else if (ne01 == 7) {
NCHW2NHWC<half, half, 7, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
constexpr unsigned int mask = filter_swizzle_mask(7, CUDA_NCHW_2_NHWC_BLOCK_C);
NCHW2NHWC<half, half, mask, 7, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
} else if (ne01 == 6) {
NCHW2NHWC<half, half, 6, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
constexpr unsigned int mask = filter_swizzle_mask(6, CUDA_NCHW_2_NHWC_BLOCK_C);
NCHW2NHWC<half, half, mask, 6, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
} else if (ne01 == 5) {
NCHW2NHWC<half, half, 5, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
constexpr unsigned int mask = filter_swizzle_mask(5, CUDA_NCHW_2_NHWC_BLOCK_C);
NCHW2NHWC<half, half, mask, 5, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
} else if (ne01 == 4) {
NCHW2NHWC<half, half, 4, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
constexpr unsigned int mask = filter_swizzle_mask(4, CUDA_NCHW_2_NHWC_BLOCK_C);
NCHW2NHWC<half, half, mask, 4, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
} else if (ne01 == 3) {
NCHW2NHWC<half, half, 3, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
constexpr unsigned int mask = filter_swizzle_mask(3, CUDA_NCHW_2_NHWC_BLOCK_C);
NCHW2NHWC<half, half, mask, 3, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
} else if (ne01 == 2) {
NCHW2NHWC<half, half, 2, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
constexpr unsigned int mask = filter_swizzle_mask(2, CUDA_NCHW_2_NHWC_BLOCK_C);
NCHW2NHWC<half, half, mask, 2, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
} else {
dim3 dimGrid2((ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
(ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
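
The ladder above maps the runtime filter size ne01 onto compile-time template arguments, since both rs (the shared-tile row count) and the swizzle mask must be constants. A hypothetical way to reduce the repetition (not part of this commit, and relying on the surrounding definitions of NCHW2NHWC, filter_swizzle_mask and CUDA_NCHW_2_NHWC_BLOCK_C) would be a small launch helper that derives the mask from rs in one place:

// Hypothetical helper, sketched for illustration only.
template <int rs>
static void launch_nchw2nhwc_f16(const half * src, half * dst,
                                 const int ne, const int ne00, const int ne01,
                                 const dim3 grid, cudaStream_t st) {
    constexpr unsigned int mask = filter_swizzle_mask(rs, CUDA_NCHW_2_NHWC_BLOCK_C);
    NCHW2NHWC<half, half, mask, rs, CUDA_NCHW_2_NHWC_BLOCK_C>
        <<<grid, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(src, dst, ne, ne00, ne01);
}

// The ne01 == 25 branch would then read:
//     launch_nchw2nhwc_f16<25>(K_D, kernel_f16.get(), ne, ne00, ne01, dimGrid1, st);
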


@@ -653,8 +653,8 @@ int main(void)
int k = 0;
for (auto c : configs_sdxl_1024){
// for (auto c : configs){
// for (auto c : configs_sdxl_1024){
for (auto c : configs){
test_model model;
load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c),
std::get<3>(c), std::get<4>(c), std::get<5>(c), true);
@@ -671,7 +671,7 @@ int main(void)
// fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f);
int iterations = 20;
int iterations = 0;
double run_time0;
std::vector<float> im2col_data = compute_graph(model, allocr, build_graph_0, iterations, &run_time0);