diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 64a521f616..5307e58ed7 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -672,12 +672,11 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, #pragma unroll for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++) { - const int output_sts_offset = output_sts_addr + mma_m * MMA_M * BN / 2 - i * mma_tiles_per_warp_n/2 * MMA_N; for (unsigned int mma_n = i * mma_tiles_per_warp_n/2; mma_n < (i+1)*mma_tiles_per_warp_n/2; mma_n++) { uint32_t (®_)[2] = reinterpret_cast(acc_register_[mma_m][mma_n]); - uint idx = output_sts_offset + mma_n * MMA_N; - // mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N; + uint idx = output_sts_addr + + mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N; idx = idx ^ ((idx & 0b1110000000) >> 4); uint32_t* dst_ptr = reinterpret_cast(&smemoutput[idx]); dst_ptr[0] = reg_[0]; @@ -688,24 +687,40 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, __syncthreads(); const unsigned int m_i_wn = m_idx + i * WN / 2; #pragma unroll - for (int subk = 0; subk < WN / 2; ++subk){ - const uint row = m_i_wn + subk; + for (int subk = 0; subk < WN / 4; ++subk){ + const uint row = m_i_wn + subk*2; #pragma unroll for (int j = 0; j < 4; ++j){ const uint gemm_i = n_idx + j*32; const int n = fastdiv(gemm_i, param.OHOW_fastdiv); const int col = fastmodulo(gemm_i, param.OHOW_fastdiv); + uint idx = output_lds_addr + subk*2 + j*32*BN/2; + idx = idx ^ ((idx & 0b1110000000) >> 4); + uint32_t* dst_ptr = reinterpret_cast(&smemoutput[idx]); if (n < param.n && row < param.k && col < PQ) { - uint idx = output_lds_addr + subk + j*32*BN/2; - idx = idx ^ ((idx & 0b1110000000) >> 4); if constexpr (ksplit > 0) { const uint outOffset = z * NKPQ + n * KPQ + row * PQ + col; - output[outOffset] = smemoutput[idx]; + // output[outOffset] = smemoutput[idx]; + output[outOffset] = reinterpret_cast(dst_ptr)[0]; } else { const uint outOffset = n * KPQ + row * PQ + col; - output[outOffset] = smemoutput[idx]; + // output[outOffset] = smemoutput[idx]; + output[outOffset] = reinterpret_cast(dst_ptr)[0]; + } + } + if (n < param.n && row+1 < param.k && col < PQ) { + if constexpr (ksplit > 0) { + const uint outOffset = z * NKPQ + + n * KPQ + + (row+1) * PQ + col; + // output[outOffset] = smemoutput[idx]; + output[outOffset] = reinterpret_cast(dst_ptr)[1]; + } else { + const uint outOffset = n * KPQ + (row+1) * PQ + col; + // output[outOffset] = smemoutput[idx]; + output[outOffset] = reinterpret_cast(dst_ptr)[1]; } } } @@ -803,6 +818,9 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa constexpr unsigned int WM_dim = BM_dim / WARPS_PER_BLOCK_M; constexpr unsigned int WN_dim = BN_dim / WARPS_PER_BLOCK_N; constexpr unsigned int WK_dim = BK_dim / WARPS_PER_BLOCK_K; + + static_assert(WN_dim % 4 == 0, "final output requires this to be bank conflicts free"); + const unsigned int BlocksM = (P.n * P.Oh * P.Ow + BM_dim - 1) / BM_dim; const unsigned int BlocksN = (P.k + BN_dim - 1) / BN_dim; constexpr unsigned int ThreadsM = WARPS_PER_BLOCK_M; @@ -812,7 +830,7 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; const unsigned int ksplit = 8; - if (BlocksM * BlocksN < nsm && P.c > 8 * ksplit) { + if (BlocksM * BlocksN < nsm && P.c >= 8 * ksplit && (P.c * P.r * P.s) % (8*ksplit) == 0) { ggml_cuda_pool_alloc Y_H(ctx.pool(id), ksplit * P.k * P.Oh * P.Ow * P.n); cudaFuncSetAttribute(conv2d_implicit_kernel, diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 177288c811..16861c71c9 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -5848,6 +5848,14 @@ static std::vector> make_test_cases_eval() { } } + test_cases.emplace_back(new test_conv_2d( { 24, 24, 32, 1 }, { 3, 3, 32, 8}, + GGML_TYPE_F16, 1, 1, 1, 1, 1, 1, false)); + test_cases.emplace_back(new test_conv_2d( { 24, 24, 96, 1 }, { 3, 3, 96, 8}, + GGML_TYPE_F16, 1, 1, 1, 1, 1, 1, false)); + test_cases.emplace_back(new test_conv_2d( { 24, 24, 128, 1 }, { 3, 3, 128, 8}, + GGML_TYPE_F16, 1, 1, 1, 1, 1, 1, false)); + + // sycl backend will limit task global_range < MAX_INT // test cases for 2D im2col with large input W and H (occurs in stable-diffusion) // however these cases need to alloc more memory which may fail in some devices (Intel Arc770, etc.) diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index f671c4606a..0b1b5c476f 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -305,7 +305,7 @@ int main(void) // std::make_tuple(640,640,52,76,3,3), // std::make_tuple(640,640,104,152,3,3), // std::make_tuple(960,320,104,152,3,3), - // std::make_tuple(1280,1280,26,38,3,3), + std::make_tuple(1280,1280,26,38,3,3), // std::make_tuple(4,320,96,128,3,3), // std::make_tuple(320,4,96,128,3,3), // std::make_tuple(4,320,64,96,3,3), @@ -538,108 +538,108 @@ int main(void) //1024x1024 - std::make_tuple(4,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(320,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(320,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(640,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(640,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(2560,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(2560,1280,32,32,3,3), - std::make_tuple(2560,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(2560,1280,32,32,3,3), - std::make_tuple(1920,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1920,1280,32,32,3,3), - std::make_tuple(1280,1280,64,64,3,3), - std::make_tuple(1920,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(1920,640,64,64,3,3), - std::make_tuple(1280,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(1280,640,64,64,3,3), - std::make_tuple(960,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(960,640,64,64,3,3), - std::make_tuple(640,640,128,128,3,3), - std::make_tuple(960,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(960,320,128,128,3,3), - std::make_tuple(640,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(640,320,128,128,3,3), - std::make_tuple(640,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(640,320,128,128,3,3), - std::make_tuple(320,4,128,128,3,3), - std::make_tuple(4,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(320,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(320,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(640,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(640,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(2560,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(2560,1280,32,32,3,3), - std::make_tuple(2560,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(2560,1280,32,32,3,3), - std::make_tuple(1920,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1920,1280,32,32,3,3), - std::make_tuple(1280,1280,64,64,3,3), - std::make_tuple(1920,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(1920,640,64,64,3,3), - std::make_tuple(1280,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(1280,640,64,64,3,3), - std::make_tuple(960,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(960,640,64,64,3,3), - std::make_tuple(640,640,128,128,3,3), - std::make_tuple(960,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(960,320,128,128,3,3), - std::make_tuple(640,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(640,320,128,128,3,3), - std::make_tuple(640,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(640,320,128,128,3,3), - std::make_tuple(320,4,128,128,3,3), + // std::make_tuple(4,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(320,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(320,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(640,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(640,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(2560,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(2560,1280,32,32,3,3), + // std::make_tuple(2560,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(2560,1280,32,32,3,3), + // std::make_tuple(1920,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(1920,1280,32,32,3,3), + // std::make_tuple(1280,1280,64,64,3,3), + // std::make_tuple(1920,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(1920,640,64,64,3,3), + // std::make_tuple(1280,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(1280,640,64,64,3,3), + // std::make_tuple(960,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(960,640,64,64,3,3), + // std::make_tuple(640,640,128,128,3,3), + // std::make_tuple(960,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(960,320,128,128,3,3), + // std::make_tuple(640,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(640,320,128,128,3,3), + // std::make_tuple(640,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(640,320,128,128,3,3), + // std::make_tuple(320,4,128,128,3,3), + // std::make_tuple(4,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(320,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(320,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(640,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(640,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(2560,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(2560,1280,32,32,3,3), + // std::make_tuple(2560,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(2560,1280,32,32,3,3), + // std::make_tuple(1920,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(1920,1280,32,32,3,3), + // std::make_tuple(1280,1280,64,64,3,3), + // std::make_tuple(1920,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(1920,640,64,64,3,3), + // std::make_tuple(1280,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(1280,640,64,64,3,3), + // std::make_tuple(960,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(960,640,64,64,3,3), + // std::make_tuple(640,640,128,128,3,3), + // std::make_tuple(960,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(960,320,128,128,3,3), + // std::make_tuple(640,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(640,320,128,128,3,3), + // std::make_tuple(640,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(640,320,128,128,3,3), + // std::make_tuple(320,4,128,128,3,3), }; @@ -663,7 +663,7 @@ int main(void) // fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f); - int iterations = 20; + int iterations = 0; double run_time0; std::vector im2col_data = compute_graph(model, allocr, build_graph_0, iterations, &run_time0); @@ -705,16 +705,16 @@ int main(void) // for(int i = 0; i < ggml_nelements(wino_res); i++) { - // for(int i = 0; i < 26*38; i++) { - // // for(int i = 0; i < conv2d_data.size(); i++) { - // float diff = fabs(im2col_data[i] - conv2d_data[i]); - // // if(diff > 0.5) { - // printf("(%7.3f, %7.3f, %.2f, %d) \n", - // im2col_data[i], conv2d_data[i], - // diff, i); - // // break; - // // } - // } + for(int i = 0; i < 26*38; i++) { + // for(int i = 0; i < conv2d_data.size(); i++) { + float diff = fabs(im2col_data[i] - conv2d_data[i]); + // if(diff > 0.5) { + printf("(%7.3f, %7.3f, %.2f, %d) \n", + im2col_data[i], conv2d_data[i], + diff, i); + // break; + // } + } ggml_free(model.ctx); ggml_backend_buffer_free(model.buffer);