The specialized filter-transpose kernel (the swizzled NCHW2NHWC variant) is broken; disable it and fall back to the generic, less optimized transpose kernel.

This commit is contained in:
bssrdf 2025-11-15 22:37:52 -05:00
parent fa7dd684bf
commit 3591e83db9
3 changed files with 81 additions and 82 deletions

View File

@ -82,47 +82,48 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co
} }
} }
// Transposes a batch of filter matrices from channel-major to channel-minor
// order, staging each block's slice in a swizzled shared-memory tile.
//
// Data layout implied by the indexing below:
//   src : nmat matrices, each ne00 rows (channels) x rs columns, row-major
//   dst : nmat matrices, each rs rows x ne00 columns (channels), row-major
// where nmat = ne / (ne00 * ne01).
//
// Template parameters:
//   src_T, dst_T : element types; values are converted with ggml_cuda_cast.
//   mask         : bit mask feeding the XOR swizzle applied to tile indices.
//   rs           : number of columns per source row. The indexing assumes
//                  rs == ne01; the visible dispatch code instantiates rs with
//                  the value of ne01.
//   blk_c        : channels handled per block.
//
// Launch requirements:
//   blockDim.x == blk_c
//   gridDim.x  >= ceil(ne00 / blk_c)
//   gridDim.y  >= ceil(nmat / CUDA_NCHW_2_NHWC_BLOCK_NM)
//   static shared memory: rs * blk_c * sizeof(src_T)
//
// Fixes relative to the previous version:
//   * Threads whose channel index is >= ne00 no longer write to dst. The old
//     guard (dst_index < ne) let them spill into the next output row whenever
//     ne00 was not a multiple of blk_c.
//   * Loads are bounded by the current matrix instead of by the whole buffer.
//   * A barrier now separates reading the tile from refilling it in the next
//     imat iteration (write-after-read race on shared memory).
//   * Global offsets are computed in 64 bits instead of being truncated to
//     unsigned int.
//
// NOTE(review): the swizzle must map [0, rs*blk_c) onto itself; this depends
// on the mask produced by filter_swizzle_mask, which is not visible here —
// confirm before re-enabling this kernel.
template <typename src_T, typename dst_T, const unsigned int mask, const int rs, const unsigned int blk_c>
static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, const int ne00, const int ne01){

    const int64_t n    = (int64_t) ne00 * ne01;   // elements per matrix
    const int64_t nmat = ne / n;                  // number of matrices

    const unsigned int tx = threadIdx.x;
    const unsigned int bx = blockIdx.x;
    const unsigned int by = blockIdx.y;

    // Channel this thread writes in the output, and the offset of this block's
    // contiguous source slice inside one matrix.
    const unsigned int chan    = bx * blk_c + tx;
    const int64_t      blk_off = (int64_t) bx * blk_c * rs;

    __shared__ src_T tile[rs*blk_c];

#pragma unroll
    for (int i = 0; i < CUDA_NCHW_2_NHWC_BLOCK_NM; ++i) {

        const int64_t imat = (int64_t) by * CUDA_NCHW_2_NHWC_BLOCK_NM + i;
        // imat depends only on blockIdx and i, so the whole block leaves the
        // loop together and no thread is left waiting at a barrier.
        if (imat >= nmat) {
            break;
        }

        const int64_t mat_off = imat * n;

        // Stage: coalesced read of the block's slice, stored transposed
        // (row = position within the filter, col = channel within the block).
#pragma unroll
        for (unsigned int j = 0; j < rs; j++) {
            const unsigned int e     = j * blk_c + tx;   // linear index in the slice
            const unsigned int row   = e % rs;
            const unsigned int col   = e / rs;
            const int64_t      local = blk_off + e;      // offset inside this matrix

            unsigned int idx = row * blk_c + col;
            idx = idx ^ ((idx & mask) >> 4);

            // Stay inside the current matrix: the last channel block may be
            // only partially populated.
            if (local < n) {
                tile[idx] = src[mat_off + local];
            }
        }

        __syncthreads();

        // Write-out: one output row per j, consecutive threads write
        // consecutive channels.
#pragma unroll
        for (unsigned int j = 0; j < rs; j++) {
            if (chan < (unsigned int) ne00) {
                unsigned int idx = j * blk_c + tx;
                idx = idx ^ ((idx & mask) >> 4);
                dst[mat_off + (int64_t) j * ne00 + chan] = ggml_cuda_cast<dst_T>(tile[idx]);
            }
        }

        // The next iteration overwrites the tile; wait until every thread has
        // finished reading it.
        __syncthreads();
    }
}
@ -956,10 +957,10 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
// if(block_k == num_block_tiles_k) // if(block_k == num_block_tiles_k)
// break; // break;
if(thread_idx == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0){ // if(thread_idx == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0){
printf(" s = %d, r = %d, block_k = %d, next_idx = %d , %d, %d, %d \n", s, r, block_k, next_idx, // printf(" s = %d, r = %d, block_k = %d, next_idx = %d , %d, %d, %d \n", s, r, block_k, next_idx,
block_krs, num_block_tiles_k, num_block_tiles_krs); // block_krs, num_block_tiles_k, num_block_tiles_krs);
} // }
// if (block_k != num_block_tiles_k){ // if (block_k != num_block_tiles_k){
if (block_krs != num_block_tiles_krs){ if (block_krs != num_block_tiles_krs){
@ -1025,7 +1026,6 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
} }
} }
// if(threadIdx.x >= 8 && threadIdx.x < 12 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0){ // if(threadIdx.x >= 8 && threadIdx.x < 12 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0){
// printf("A %d, %d, %d: %f, %f \n", block_krs, mma_k, threadIdx.x, // printf("A %d, %d, %d: %f, %f \n", block_krs, mma_k, threadIdx.x,
// __half2float(A_register_[1][mma_k][0]), // __half2float(A_register_[1][mma_k][0]),
@ -1153,13 +1153,13 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
// __half2float(acc_register_[1][1][2]), // __half2float(acc_register_[1][1][2]),
// __half2float(acc_register_[1][1][3])); // __half2float(acc_register_[1][1][3]));
// } // }
if(threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0){ // if(threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0){
printf(" %u, %f, %f, %f, %f\n", blockIdx.z, // printf(" %u, %f, %f, %f, %f\n", blockIdx.z,
__half2float(acc_register_[0][1][0]), // __half2float(acc_register_[0][1][0]),
__half2float(acc_register_[0][1][1]), // __half2float(acc_register_[0][1][1]),
__half2float(acc_register_[0][1][2]), // __half2float(acc_register_[0][1][2]),
__half2float(acc_register_[0][1][3])); // __half2float(acc_register_[0][1][3]));
} // }
// reuse smem // reuse smem
half *smemoutput = shmem; half *smemoutput = shmem;
@ -1341,42 +1341,42 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
dim3 dimGrid1((ne00 + CUDA_NCHW_2_NHWC_BLOCK_C - 1) / CUDA_NCHW_2_NHWC_BLOCK_C, dim3 dimGrid1((ne00 + CUDA_NCHW_2_NHWC_BLOCK_C - 1) / CUDA_NCHW_2_NHWC_BLOCK_C,
(ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM, (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM,
1) ; 1) ;
if (ne01 == 25) { // if (ne01 == 25) {
constexpr unsigned int mask = filter_swizzle_mask(25, CUDA_NCHW_2_NHWC_BLOCK_C); // constexpr unsigned int mask = filter_swizzle_mask(25, CUDA_NCHW_2_NHWC_BLOCK_C);
NCHW2NHWC<half, half, mask, 25, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01); // NCHW2NHWC<half, half, mask, 25, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
}else if (ne01 == 16) { // }else if (ne01 == 16) {
constexpr unsigned int mask = filter_swizzle_mask(16, CUDA_NCHW_2_NHWC_BLOCK_C); // constexpr unsigned int mask = filter_swizzle_mask(16, CUDA_NCHW_2_NHWC_BLOCK_C);
NCHW2NHWC<half, half, mask, 16, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01); // NCHW2NHWC<half, half, mask, 16, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
}else if (ne01 == 9) { // }else if (ne01 == 9) {
constexpr unsigned int mask = filter_swizzle_mask(9, CUDA_NCHW_2_NHWC_BLOCK_C); // constexpr unsigned int mask = filter_swizzle_mask(9, CUDA_NCHW_2_NHWC_BLOCK_C);
NCHW2NHWC<half, half, mask, 9, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01); // NCHW2NHWC<half, half, mask, 9, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
} else if (ne01 == 8) { // } else if (ne01 == 8) {
constexpr unsigned int mask = filter_swizzle_mask(8, CUDA_NCHW_2_NHWC_BLOCK_C); // constexpr unsigned int mask = filter_swizzle_mask(8, CUDA_NCHW_2_NHWC_BLOCK_C);
NCHW2NHWC<half, half, mask, 8, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01); // NCHW2NHWC<half, half, mask, 8, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
} else if (ne01 == 7) { // } else if (ne01 == 7) {
constexpr unsigned int mask = filter_swizzle_mask(7, CUDA_NCHW_2_NHWC_BLOCK_C); // constexpr unsigned int mask = filter_swizzle_mask(7, CUDA_NCHW_2_NHWC_BLOCK_C);
NCHW2NHWC<half, half, mask, 7, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01); // NCHW2NHWC<half, half, mask, 7, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
} else if (ne01 == 6) { // } else if (ne01 == 6) {
constexpr unsigned int mask = filter_swizzle_mask(6, CUDA_NCHW_2_NHWC_BLOCK_C); // constexpr unsigned int mask = filter_swizzle_mask(6, CUDA_NCHW_2_NHWC_BLOCK_C);
NCHW2NHWC<half, half, mask, 6, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01); // NCHW2NHWC<half, half, mask, 6, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
} else if (ne01 == 5) { // } else if (ne01 == 5) {
constexpr unsigned int mask = filter_swizzle_mask(5, CUDA_NCHW_2_NHWC_BLOCK_C); // constexpr unsigned int mask = filter_swizzle_mask(5, CUDA_NCHW_2_NHWC_BLOCK_C);
NCHW2NHWC<half, half, mask, 5, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01); // NCHW2NHWC<half, half, mask, 5, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
} else if (ne01 == 4) { // } else if (ne01 == 4) {
constexpr unsigned int mask = filter_swizzle_mask(4, CUDA_NCHW_2_NHWC_BLOCK_C); // constexpr unsigned int mask = filter_swizzle_mask(4, CUDA_NCHW_2_NHWC_BLOCK_C);
NCHW2NHWC<half, half, mask, 4, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01); // NCHW2NHWC<half, half, mask, 4, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
} else if (ne01 == 3) { // } else if (ne01 == 3) {
constexpr unsigned int mask = filter_swizzle_mask(3, CUDA_NCHW_2_NHWC_BLOCK_C); // constexpr unsigned int mask = filter_swizzle_mask(3, CUDA_NCHW_2_NHWC_BLOCK_C);
NCHW2NHWC<half, half, mask, 3, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01); // NCHW2NHWC<half, half, mask, 3, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
} else if (ne01 == 2) { // } else if (ne01 == 2) {
constexpr unsigned int mask = filter_swizzle_mask(2, CUDA_NCHW_2_NHWC_BLOCK_C); // constexpr unsigned int mask = filter_swizzle_mask(2, CUDA_NCHW_2_NHWC_BLOCK_C);
NCHW2NHWC<half, half, mask, 2, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01); // NCHW2NHWC<half, half, mask, 2, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
} else { // } else {
dim3 dimGrid2((ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, dim3 dimGrid2((ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
(ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
(ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ; (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ;
NCHW2NHWC<half, half><<<dimGrid2, dimBlock, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01); NCHW2NHWC<half, half><<<dimGrid2, dimBlock, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
} // }
} }
const half *X_H = input_f16.get(); const half *X_H = input_f16.get();

View File

@ -5826,7 +5826,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
for (uint32_t s0 : { 1, 3 }) { for (uint32_t s0 : { 1, 3 }) {
for (uint32_t p1 : { 2, 5 }) { for (uint32_t p1 : { 2, 5 }) {
for (uint32_t Cin : { 1, 25 }) { for (uint32_t Cin : { 1, 25, 48 }) {
for (uint32_t Cout : { 1, 12 }) { for (uint32_t Cout : { 1, 12 }) {
for (uint32_t KH : { 1, 2, 3, 11 }) { for (uint32_t KH : { 1, 2, 3, 11 }) {
for (uint32_t KW : { 1, 2, 3, 11 }) { for (uint32_t KW : { 1, 2, 3, 11 }) {

View File

@ -752,7 +752,6 @@ int main(void)
// for(int i = 0; i < conv2d_data.size(); i++) { // for(int i = 0; i < conv2d_data.size(); i++) {
// float diff = fabs(im2col_data[i] - conv2d_data[i]); // float diff = fabs(im2col_data[i] - conv2d_data[i]);
// // if(diff > 0.5) { // // if(diff > 0.5) {
// // if(diff > 2.0) {
// printf("(%7.3f, %7.3f, %.2f, %d) \n", // printf("(%7.3f, %7.3f, %.2f, %d) \n",
// im2col_data[i], conv2d_data[i], // im2col_data[i], conv2d_data[i],
// diff, i); // diff, i);