fixed a bug in the special filter transpose NCHW2NHWC; still failing for channel number < 32

This commit is contained in:
bssrdf 2025-11-16 09:31:38 -05:00
parent 721fa41076
commit bccd869968
3 changed files with 24 additions and 21 deletions

View File

@ -93,6 +93,8 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co
const unsigned int bx = blockIdx.x;
const unsigned int by = blockIdx.y;
const unsigned int blk = (bx+1) * blk_c <= ne00 ? blk_c : ne00 - bx * blk_c;
__shared__ src_T tile[rs*blk_c];
#pragma unroll
@ -103,12 +105,12 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co
break;
#pragma unroll
for (unsigned int j = 0; j < rs; j++){
const unsigned int row = (j * blk_c + tx) % rs;
const unsigned int col = (j * blk_c + tx) / rs;
const unsigned int src_index = imat*n + bx * blk_c * rs + j * blk_c + tx;
const unsigned int row = (j * blk + tx) % rs;
const unsigned int col = (j * blk + tx) / rs;
const unsigned int src_index = imat*n + bx * blk_c * rs + j * blk + tx;
// const unsigned int src_index = imat*n + rs*ne00 + bx * blk_c + j * blk_c + tx;
unsigned int idx = row * blk_c + col;
// idx = idx ^ ((idx & mask) >> 4);
idx = idx ^ ((idx & mask) >> 4);
if (src_index < ne) {
tile[idx] = src[src_index];
}
@ -117,9 +119,9 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co
#pragma unroll
for (unsigned int j = 0; j < rs; j++){
const unsigned int dst_index = imat*n + j*ne00 + bx*blk_c + tx;
if(dst_index < ne){
if(dst_index < ne && tx < blk){
unsigned int idx = j*blk_c + tx;
// idx = idx ^ ((idx & mask) >> 4);
idx = idx ^ ((idx & mask) >> 4);
dst[dst_index] = ggml_cuda_cast<dst_T>(tile[idx]);
}
}
@ -1338,17 +1340,17 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
// ggml_cuda_pool_alloc<half> kernel_f16(ctx.pool(id), ne);
ggml_cuda_pool_alloc<half> kernel_f16(ctx.pool(id));
if (ne01 > 1){
kernel_f16.alloc(ne);
// kernel_f16.alloc(ne);
// dim3 dimGrid1((ne00 + CUDA_NCHW_2_NHWC_BLOCK_C - 1) / CUDA_NCHW_2_NHWC_BLOCK_C,
// (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM,
// 1) ;
// if (ne01 == 25) {
// constexpr unsigned int mask = filter_swizzle_mask(25, CUDA_NCHW_2_NHWC_BLOCK_C);
// NCHW2NHWC<half, half, mask, 25, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
// }else if (ne01 == 16) {
// } else if (ne01 == 16) {
// constexpr unsigned int mask = filter_swizzle_mask(16, CUDA_NCHW_2_NHWC_BLOCK_C);
// NCHW2NHWC<half, half, mask, 16, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
// }else if (ne01 == 9) {
// } else if (ne01 == 9) {
// constexpr unsigned int mask = filter_swizzle_mask(9, CUDA_NCHW_2_NHWC_BLOCK_C);
// NCHW2NHWC<half, half, mask, 9, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
// } else if (ne01 == 8) {

View File

@ -5826,7 +5826,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
for (uint32_t s0 : { 1, 3 }) {
for (uint32_t p1 : { 2, 5 }) {
for (uint32_t Cin : { 1, 25, 48 }) {
for (uint32_t Cin : { 1, 16, 25, 48 }) {
for (uint32_t Cout : { 1, 12 }) {
for (uint32_t KH : { 1, 2, 3, 11 }) {
for (uint32_t KW : { 1, 2, 3, 11 }) {

View File

@ -319,7 +319,7 @@ static std::vector<std::tuple<int, int, int, int, int, int>> configs = {
// std::make_tuple(960,320,104,152,3,3),
// std::make_tuple(1280,1280,26,38,3,3),
// std::make_tuple(1920,640,32,32,3,3)
std::make_tuple(1280,1280,16,16,3,3),
// std::make_tuple(1280,1280,16,16,3,3),
// std::make_tuple(32,12,141,133,3,3),
// std::make_tuple(32,6,141,133,3,3),
// std::make_tuple(32,12,141,121,3,3),
@ -328,7 +328,8 @@ static std::vector<std::tuple<int, int, int, int, int, int>> configs = {
// std::make_tuple(320,9,16,16,3,3), //working
// std::make_tuple(320,12,16,16,3,3), //working
// std::make_tuple(256,12,16,16,3,3), //working
// std::make_tuple(32,12,16,16,3,3), //not working
// std::make_tuple(32,12,16,16,3,3), //not working
std::make_tuple(16,12,16,16,3,3), //not working
// std::make_tuple(48,12,16,16,3,3), // not working
// std::make_tuple(96,12,16,16,3,3), //not working
// std::make_tuple(64,12,16,16,3,3), //working
@ -749,15 +750,15 @@ int main(void)
// int i = 2048;
// for(int i = 0; i < ggml_nelements(wino_res); i++) {
// for(int i = 0; i < 26*38; i++) {
// for(int i = 0; i < conv2d_data.size(); i++) {
// float diff = fabs(im2col_data[i] - conv2d_data[i]);
// // if(diff > 0.5) {
// printf("(%7.3f, %7.3f, %.2f, %d) \n",
// im2col_data[i], conv2d_data[i],
// diff, i);
// // break;
// // }
// }
for(int i = 0; i < conv2d_data.size(); i++) {
float diff = fabs(im2col_data[i] - conv2d_data[i]);
// if(diff > 0.5) {
printf("(%7.3f, %7.3f, %.2f, %d) \n",
im2col_data[i], conv2d_data[i],
diff, i);
// break;
// }
}
ggml_free(model.ctx);
ggml_backend_buffer_free(model.buffer);