From 55859a86aa466ff5cd6836fb85a21ebd56dde282 Mon Sep 17 00:00:00 2001
From: bssrdf
Date: Wed, 29 Oct 2025 21:36:03 -0400
Subject: [PATCH] remove implicit op and related calls; replace conv_2d with
 conv_2d_implicit kernel

---
 ggml/include/ggml.h                   |  14 ---
 ggml/src/ggml-cpu/ggml-cpu.c          |   6 --
 ggml/src/ggml-cuda/conv2d-implicit.cu | 120 ++++++++++++++--------
 ggml/src/ggml-cuda/ggml-cuda.cu       |   6 +-
 ggml/src/ggml.c                       |  66 +-----------
 tests/test-backend-ops.cpp            | 139 +++-----------------------
 tests/test-conv2d-implicit.cpp        |  99 ++++--------------
 7 files changed, 117 insertions(+), 333 deletions(-)

diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 26d6f3332c..b7b472c56e 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -513,7 +513,6 @@ extern "C" {
         GGML_OP_IM2COL_BACK,
         GGML_OP_IM2COL_3D,
         GGML_OP_CONV_2D,
-        GGML_OP_CONV_2D_IMPLICIT,
         GGML_OP_CONV_3D,
         GGML_OP_CONV_2D_DW,
         GGML_OP_CONV_TRANSPOSE_2D,
@@ -1983,19 +1982,6 @@ extern "C" {
             int                   d0,  // dilation dimension 0
             int                   d1); // dilation dimension 1
 
-    GGML_API struct ggml_tensor * ggml_conv_2d_implicitgemm(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,   // convolution kernel [KW, KH, IC, OC]
-            struct ggml_tensor  * b,   // input data [W, H, C, N]
-            int                   s0,  // stride dimension 0
-            int                   s1,  // stride dimension 1
-            int                   p0,  // padding dimension 0
-            int                   p1,  // padding dimension 1
-            int                   d0,  // dilation dimension 0
-            int                   d1);
-            // int                layout); // for future
-
-
     GGML_API struct ggml_tensor * ggml_conv_3d_direct(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,   // kernel [KW, KH, KD, IC * OC]
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index 6b6efebad5..c131290849 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -1887,10 +1887,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_conv_2d(params, tensor);
             } break;
-        case GGML_OP_CONV_2D_IMPLICIT:
-            {
-                ggml_compute_forward_conv_2d(params, tensor);
-            } break;
         case GGML_OP_CONV_3D:
             {
                 ggml_compute_forward_conv_3d(params, tensor);
@@ -2268,7 +2264,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
        case GGML_OP_IM2COL_BACK:
        case GGML_OP_IM2COL_3D:
        case GGML_OP_CONV_2D:
-       case GGML_OP_CONV_2D_IMPLICIT:
        case GGML_OP_CONV_3D:
        case GGML_OP_CONV_2D_DW:
        case GGML_OP_CONV_TRANSPOSE_1D:
@@ -2794,7 +2789,6 @@ struct ggml_cplan ggml_graph_plan(
                 }
             } break;
         case GGML_OP_CONV_2D:
-        case GGML_OP_CONV_2D_IMPLICIT:
         case GGML_OP_CONV_3D:
             {
                 cur = GGML_IM2COL_WORK_SIZE;
diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu
index 6b7efbe789..fcb053c61d 100644
--- a/ggml/src/ggml-cuda/conv2d-implicit.cu
+++ b/ggml/src/ggml-cuda/conv2d-implicit.cu
@@ -7,6 +7,9 @@
 typedef unsigned int uint;
 
 constexpr uint WARPSIZE = 32;
+#define CUDA_NCHW_2_NHWC_TILE_DIM 32
+#define CUDA_NCHW_2_NHWC_BLOCK_NM 8
+#define CUDA_NCHW_2_NHWC_BLOCK_ROWS 8
 
 //currently not use; in future for split-k kernels
 
@@ -23,6 +26,41 @@ static __global__ void reduce_f32(const float * __restrict__ x, float * __restri
     }
 }
 
+template <typename src_T, typename dst_T>
+static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, const int ne00, const int ne01){
+
+    const int64_t nmat = ne / (ne00 * ne01);
+    const int64_t n = ne00 * ne01;
+
+    int x  = blockIdx.x * CUDA_NCHW_2_NHWC_TILE_DIM + threadIdx.x;
+    int y  = blockIdx.y * CUDA_NCHW_2_NHWC_TILE_DIM + threadIdx.y;
+    int tx = blockIdx.y * CUDA_NCHW_2_NHWC_TILE_DIM + threadIdx.x; // transpose block offset
+    int ty = blockIdx.x * CUDA_NCHW_2_NHWC_TILE_DIM + threadIdx.y;
+
+    __shared__ src_T tile[CUDA_NCHW_2_NHWC_TILE_DIM][CUDA_NCHW_2_NHWC_TILE_DIM];
+
+    for(int i = 0; i < CUDA_NCHW_2_NHWC_BLOCK_NM; ++i){
+
+        const unsigned int imat = blockIdx.z * CUDA_NCHW_2_NHWC_BLOCK_NM + i;
+        if(imat >= nmat)
+            break;
+        for (int j = 0; j < CUDA_NCHW_2_NHWC_TILE_DIM; j += CUDA_NCHW_2_NHWC_BLOCK_ROWS){
+            if(x < ne01 && y + j < ne00){
+                const int row = threadIdx.y+j;
+                const int col = threadIdx.x ^ row;
+                tile[row][col] = src[imat*n + (y+j)*ne01 + x];
+            }
+        }
+        __syncthreads();
+
+        for (int j = 0; j < CUDA_NCHW_2_NHWC_TILE_DIM; j += CUDA_NCHW_2_NHWC_BLOCK_ROWS){
+            if(ty + j < ne01 && tx < ne00){
+                const int col = (threadIdx.y+j) ^ threadIdx.x;
+                dst[imat*n + (ty+j)*ne00 + tx] = ggml_cuda_cast<dst_T>(tile[threadIdx.x][col]);
+            }
+        }
+    }
+}
 
 template<<>>(X_D, K_D, Y_D, P);
-        else if(P.layout == 1)
-            conv2d_implicit_kernel<<>>(X_D, K_D, Y_D, P);
-    } else{
-        if(P.layout == 0)
-            conv2d_implicit_kernel<<>>(X_D, K_D, Y_D, P);
-        else if(P.layout == 1)
-            conv2d_implicit_kernel<<>>(X_D, K_D, Y_D, P);
-    }
+
+    conv2d_implicit_kernel<<>>(X_D, K_D, Y_D, P);
 }
 
 static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const float * X_D, const half * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) {
 
-    if (GGML_CUDA_CC_IS_NVIDIA(cc) && ampere_mma_available(cc) && P.layout == 0 && P.c % 8 == 0) {
+    if (GGML_CUDA_CC_IS_NVIDIA(cc) && ampere_mma_available(cc) && P.c % 8 == 0 && (P.r > 1 || P.s > 1)) {
+
+        int id = ggml_cuda_get_device();
+
+        int64_t ne   = P.c * P.h * P.w * P.n;
+        int64_t ne00 = P.c;
+        int64_t ne01 = P.h * P.w;
+        ggml_cuda_pool_alloc<half> input_f16(ctx.pool(id), ne);
+
+        dim3 dimGrid( (ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
+                      (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
+                      (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ;
+        dim3 dimBlock(CUDA_NCHW_2_NHWC_TILE_DIM,CUDA_NCHW_2_NHWC_BLOCK_ROWS, 1);
+        NCHW2NHWC<<<dimGrid, dimBlock, 0, st>>>(X_D, input_f16.get(), ne, ne00, ne01);
+
+        ne   = P.c * P.r * P.s * P.k;
+        ne01 = P.r * P.s;
+        ggml_cuda_pool_alloc<half> kernel_f16(ctx.pool(id), ne);
+        dim3 dimGrid1((ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
+                      (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
+                      (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ;
+        NCHW2NHWC<<<dimGrid1, dimBlock, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
+
+        const half *X_H = input_f16.get();
+        const half *K_H = kernel_f16.get();
+        ggml_cuda_pool_alloc<half> Y_H(ctx.pool(id), P.k * P.Oh * P.Ow * P.n);
+
         constexpr unsigned int BM_dim = 256;
         constexpr unsigned int BN_dim = 256;
         constexpr unsigned int BK_dim = 32;
@@ -925,19 +977,9 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
 
         dim3 gridDim(BlocksN, BlocksM);
         dim3 blockDim(ThreadsN, ThreadsM);
-        int id = ggml_cuda_get_device();
-        ggml_cuda_pool_alloc<half> x_f16(ctx.pool(id));
-
-        const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(GGML_TYPE_F32);
-        GGML_ASSERT(to_fp16_cuda != nullptr);
-        size_t ne = P.c * P.h * P.w * P.n;
-        x_f16.alloc(ne);
-        to_fp16_cuda(X_D, x_f16.get(), ne, st);
-        const half *X_H = x_f16.get();
-        ggml_cuda_pool_alloc<half> Y_H(ctx.pool(id), P.k * P.Oh * P.Ow * P.n);
         conv2d_implicit_kernel
-            <<>>(X_H, K_D, Y_H.get(), P);
+            <<>>(X_H, K_H, Y_H.get(), P);
         const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
         to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow * P.n, st);
     } else{
@@ -971,28 +1013,28 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor *
     const int PD_Y = p[3]; // padding_y
     const int DL_X = p[4]; // dilation_x
     const int DL_Y = p[5]; // dilation_y
-    const int LT = p[6]; // layout
+    // const int LT = p[6]; // layout
 
-    GGML_ASSERT(LT == 0 || LT == 1);
+    // GGML_ASSERT(LT == 0 || LT == 1);
     // same number of input channels
-    GGML_ASSERT(LT == 0 ? input->ne[0] == kernel->ne[0] : input->ne[2] == kernel->ne[2]);
+    // GGML_ASSERT(LT == 0 ? input->ne[0] == kernel->ne[0] : input->ne[2] == kernel->ne[2]);
     // No cwhn
-    GGML_ASSERT(p[7] == false);
+    GGML_ASSERT(p[6] == false);
 
-    const int IW = input->ne[LT == 0 ? 1 : 0]; // input_w
-    const int IH = input->ne[LT == 0 ? 2 : 1]; // input_h
+    const int IW = input->ne[0]; // input_w
+    const int IH = input->ne[1]; // input_h
     const int OW = dst->ne[0]; // output_w
     const int OH = dst->ne[1]; // output_h
-    const int KW = kernel->ne[LT == 0 ? 1 : 0]; // kernel_w
-    const int KH = kernel->ne[LT == 0 ? 2 : 1]; // kernel_h
-    const int IC = input->ne[LT == 0 ? 0: 2]; // input_channels
+    const int KW = kernel->ne[0]; // kernel_w
+    const int KH = kernel->ne[1]; // kernel_h
+    const int IC = input->ne[2]; // input_channels
     const int OC = kernel->ne[3]; // ouptut_chanles
     const int B = input->ne[3]; // n_batches
-    
+
     const int64_t total = B * OC * OH * OW;
-    
+
     param_t params = { B, IC, IH, IW, OC, KH, KW, ST_Y, ST_X, PD_Y, PD_X, DL_Y, DL_X, OH, OW };
     params.SC_fastdiv = init_fastdiv_values(KW*IC);
     params.OW_fastdiv = init_fastdiv_values(OW);
@@ -1000,7 +1042,7 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor *
     params.C_fastdiv = init_fastdiv_values(IC);
     params.RS_fastdiv = init_fastdiv_values(KW*KH);
     params.S_fastdiv = init_fastdiv_values(KW);
-    params.layout = LT;
+    // params.layout = LT;
 
     if (kernel->type == GGML_TYPE_F16) {
         conv2d_implicit_cuda_f16(ctx, X_D, (half *) K_D, Y_D, cc, params, st);
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 154076f38d..29fa63777b 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2462,11 +2462,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_op_im2col_3d(ctx, dst);
             break;
         case GGML_OP_CONV_2D:
-            ggml_cuda_op_conv2d(ctx, dst);
-            break;
-        case GGML_OP_CONV_2D_IMPLICIT:
             ggml_cuda_op_conv2d_implicit(ctx, dst);
-           break;
+            break;
         case GGML_OP_CONV_2D_DW:
             ggml_cuda_op_conv2d_dw(ctx, dst);
             break;
@@ -3580,7 +3577,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_IM2COL:
         case GGML_OP_IM2COL_3D:
         case GGML_OP_CONV_2D:
-        case GGML_OP_CONV_2D_IMPLICIT:
         case GGML_OP_CONV_2D_DW:
         case GGML_OP_CONV_TRANSPOSE_2D:
         case GGML_OP_POOL_2D:
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index bfe772697e..03c8dca3e5 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -976,7 +976,6 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "IM2COL_BACK",
     "IM2COL_3D",
     "CONV_2D",
-    "CONV_2D_IMPLICIT",
     "CONV_3D",
     "CONV_2D_DW",
     "CONV_TRANSPOSE_2D",
@@ -1020,7 +1019,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GLU",
 };
 
-static_assert(GGML_OP_COUNT == 91, "GGML_OP_COUNT != 91");
+static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1081,7 +1080,6 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "im2col_back(x)",
     "im2col_3d(x)",
     "conv_2d(x)",
-    "conv_2d_implicit(x)",
     "conv_3d(x)",
     "conv_2d_dw(x)",
     "conv_transpose_2d(x)",
@@ -1125,7 +1123,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"glu(x)", }; -static_assert(GGML_OP_COUNT == 91, "GGML_OP_COUNT != 91"); +static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -4573,66 +4571,6 @@ struct ggml_tensor * ggml_conv_2d_direct( return result; } - -// ggml_conv_2d_implicitgemm - -struct ggml_tensor * ggml_conv_2d_implicitgemm( - struct ggml_context * ctx, - struct ggml_tensor * a, // convolution kernel [KW, KH, IC, OC] - struct ggml_tensor * b, // input data [W, H, C, N] - int s0, // stride dimension 0 - int s1, // stride dimension 1 - int p0, // padding dimension 0 - int p1, // padding dimension 1 - int d0, // dilation dimension 0 - int d1){ - // 0: NHWC, 1:NCHW - // int layout) { - - GGML_ASSERT(a->ne[2] == b->ne[2]); - //GGML_ASSERT(a->type == b->type); - - int64_t ne[4]; - ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0); - ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1); - ne[2] = a->ne[3]; - ne[3] = b->ne[3]; - - struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne); - - ggml_set_op_params_i32(result, 0, s0); - ggml_set_op_params_i32(result, 1, s1); - ggml_set_op_params_i32(result, 2, p0); - ggml_set_op_params_i32(result, 3, p1); - ggml_set_op_params_i32(result, 4, d0); - ggml_set_op_params_i32(result, 5, d1); - - struct ggml_tensor *ap, *bp; - if(a->type == GGML_TYPE_F16 && (a->ne[0] > 1 || a->ne[1] > 1)){ - ggml_set_op_params_i32(result, 6, 0); - ap = ggml_reshape_4d(ctx, - ggml_cont(ctx, - ggml_transpose(ctx, - ggml_reshape_3d(ctx, a, a->ne[0]*a->ne[1], a->ne[2], a->ne[3]))), - a->ne[2], a->ne[0], a->ne[1], a->ne[3]); - bp = ggml_reshape_4d(ctx, - ggml_cont(ctx, - ggml_transpose(ctx, - ggml_reshape_3d(ctx, b, b->ne[0]*b->ne[1], b->ne[2], b->ne[3]))), - b->ne[2], b->ne[0], b->ne[1], b->ne[3]); - } else{ - ggml_set_op_params_i32(result, 6, 1); - ap = a; - bp = b; - } - - result->op = GGML_OP_CONV_2D_IMPLICIT; - result->src[0] = ap; - result->src[1] = bp; - - return result; -} - // ggml_conv_3d struct ggml_tensor * ggml_conv_3d_direct( diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index b3948a0bbf..a7aba2b447 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4191,94 +4191,6 @@ struct test_conv_2d : public test_case { } }; -// CONV_2D_IMPLICIT -struct test_conv_2d_implicit : public test_case { - const std::array ne_input; - const std::array ne_kernel; - const ggml_type type_kernel; - const int stride0; - const int stride1; - const int padding0; - const int padding1; - const int dilation0; - const int dilation1; - // Whether the inputs are contiguous in the channel dim or the width dim - const bool cwhn; - - - - std::string vars() override { - return VARS_TO_STR10(ne_input, ne_kernel, type_kernel, stride0, stride1, padding0, padding1, dilation0, dilation1, cwhn); - } - - double max_nmse_err() override { - return 5e-4; - } - - uint64_t op_flops(ggml_tensor * t) override { - GGML_UNUSED(t); - // Just counting matmul costs: - // KxCRS @ CRSxNPQ = KxNPQ --> KxNPQx(CRS+CRS-1) flops - - // Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) - auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t { - return (ins + 2 * p - d * (ks - 1) - 1) / s + 1; - }; - - int64_t W = ne_input[0]; - int64_t H = ne_input[1]; - int64_t KW = ne_kernel[0]; - int64_t KH = ne_kernel[1]; - int64_t Cin = ne_kernel[2]; - int64_t Cout = ne_kernel[3]; - int64_t N = ne_input[3]; - int64_t OH = 
-        int64_t OW = calc_conv_output_size(W, KW, stride0, padding0, dilation0);
-
-        int64_t K = Cout;
-        int64_t CRS = Cin * KH * KW;
-        int64_t NPQ = N * OH * OW;
-
-        return K * NPQ * (2 * CRS - 1);
-    }
-
-    test_conv_2d_implicit(std::array<int64_t, 4> ne_input = { 64, 64, 16, 1 },
-                 std::array<int64_t, 4> ne_kernel = { 3, 3, 1, 16 }, ggml_type type_kernel = GGML_TYPE_F32, int stride0 = 1,
-                 int stride1 = 1, int padding0 = 0, int padding1 = 0, int dilation0 = 1, int dilation1 = 1, bool cwhn = false) :
-        ne_input(ne_input),
-        ne_kernel(ne_kernel),
-        type_kernel(type_kernel),
-        stride0(stride0),
-        stride1(stride1),
-        padding0(padding0),
-        padding1(padding1),
-        dilation0(dilation0),
-        dilation1(dilation1),
-        cwhn(cwhn) {}
-
-    ggml_tensor * build_graph(ggml_context * ctx) override {
-        ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());
-        ggml_set_name(input, "input");
-
-        ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data());
-        ggml_set_name(kernel, "kernel");
-
-        // if (cwhn) {
-        //     // change memory layout to channel-most-contiguous (CWHN),
-        //     // then permute it back so NE matches the original input
-        //     input = ggml_cont(ctx, ggml_permute(ctx, input, 1, 2, 0, 3));
-        //     input = ggml_permute(ctx, input, 2, 0, 1, 3);
-        //     kernel = ggml_cont(ctx, ggml_permute(ctx, kernel, 2, 3, 1, 0));
-        //     kernel = ggml_permute(ctx, kernel, 3, 2, 0, 1);
-        // }
-
-        ggml_tensor * out =
-            ggml_conv_2d_implicitgemm(ctx, kernel, input, stride0, stride1, padding0, padding1, dilation0, dilation1);
-        ggml_set_name(out, "out");
-        return out;
-    }
-};
-
 // GGML_OP_CONV_2D_DW
 struct test_conv_2d_dw : public test_case {
     const std::array<int64_t, 4> ne_input;
@@ -5941,30 +5853,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         }
     }
 
-    for (uint32_t s0 : { 1, 3 }) {
-        for (uint32_t p1 : { 2, 5 }) {
-            for (uint32_t Cin : { 1, 25 }) {
-                for (uint32_t Cout : { 1, 12 }) {
-                    for (uint32_t KH : { 1, 2, 3, 11 }) {
-                        for (uint32_t KW : { 1, 2, 3, 11 }) {
-                            for (uint32_t H : { 1, 133 }) {
-                                for (uint32_t W : { 1, 141 }) {
-                                    if (calc_conv_output_size(W, KW, s0, p0, d0) > 0 &&
-                                        calc_conv_output_size(H, KH, s1, p1, d1) > 0) {
-                                        for (auto kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
-                                            test_cases.emplace_back(new test_conv_2d_implicit(
-                                                { W, H, Cin, 2 }, { KW, KH, Cin, Cout }, kernel_type, s0, s1, p0, p1, d0, d1, false));
-                                        }
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    }
-
     // sycl backend will limit task global_range < MAX_INT
     // test cases for 2D im2col with large input W and H (occurs in stable-diffusion)
     // however these cases need to alloc more memory which may fail in some devices (Intel Arc770, etc.)
@@ -6732,16 +6620,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
         }
     }
 
-    for (auto kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
-        for (auto act_case : cases) {
-            // Direct CONV_2D
-            test_cases.emplace_back(new test_conv_2d_implicit(
-                { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },
-                { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] },
-                kernel_type, 1, 1, 0, 0, 1, 1, false));
-        }
-    }
-
     // Stable-diffusion layers
     std::map<std::string, int> idx_sd{
         { "iw", 0 },
@@ -6788,7 +6666,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
         uint32_t p0 = act_case[idx_sd["kw"]] == 3 ? 1 : 0;
         uint32_t p1 = act_case[idx_sd["kh"]] == 3 ? 1 : 0;
 
-        test_cases.emplace_back(new test_conv_2d_implicit(
+        test_cases.emplace_back(new test_conv_2d(
             { act_case[idx_sd["iw"]], act_case[idx_sd["ih"]], act_case[idx_sd["Cin"]], act_case[idx_sd["B"]] },
             { act_case[idx_sd["kw"]], act_case[idx_sd["kh"]], act_case[idx_sd["Cin"]], act_case[idx_sd["Cout"]] },
             GGML_TYPE_F16, 1, 1, p0, p1, 1, 1, false));
@@ -6801,12 +6679,25 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
         uint32_t p0 = act_case[idx_sd["kw"]] == 3 ? 1 : 0;
         uint32_t p1 = act_case[idx_sd["kh"]] == 3 ? 1 : 0;
 
-        test_cases.emplace_back(new test_conv_2d_implicit(
+        test_cases.emplace_back(new test_conv_2d(
             { act_case[idx_sd["iw"]], act_case[idx_sd["ih"]], act_case[idx_sd["Cin"]], act_case[idx_sd["B"]] },
             { act_case[idx_sd["kw"]], act_case[idx_sd["kh"]], act_case[idx_sd["Cin"]], act_case[idx_sd["Cout"]] },
             GGML_TYPE_F32, 1, 1, p0, p1, 1, 1, false));
     }
 
+    // for (auto act_case : cases_sd) {
+    //     GGML_ASSERT(act_case[idx_sd["kw"]] == 3 || act_case[idx_sd["kw"]] == 1);
+    //     GGML_ASSERT(act_case[idx_sd["kh"]] == 3 || act_case[idx_sd["kh"]] == 1);
+
+    //     uint32_t p0 = act_case[idx_sd["kw"]] == 3 ? 1 : 0;
+    //     uint32_t p1 = act_case[idx_sd["kh"]] == 3 ? 1 : 0;
+
+    //     test_cases.emplace_back(new test_conv_2d_implicit(
+    //         { act_case[idx_sd["iw"]], act_case[idx_sd["ih"]], act_case[idx_sd["Cin"]], act_case[idx_sd["B"]] },
+    //         { act_case[idx_sd["kw"]], act_case[idx_sd["kh"]], act_case[idx_sd["Cin"]], act_case[idx_sd["Cout"]] },
+    //         GGML_TYPE_F16, 1, 1, p0, p1, 1, 1, false));
+    // }
+
     test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1}));
     test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1}));
 
diff --git a/tests/test-conv2d-implicit.cpp b/tests/test-conv2d-implicit.cpp
index 98b8b0e449..7b7a32d9f6 100644
--- a/tests/test-conv2d-implicit.cpp
+++ b/tests/test-conv2d-implicit.cpp
@@ -239,49 +239,6 @@ struct ggml_cgraph * build_graph_1(const test_model& model) {
     return gf;
 }
 
-struct ggml_cgraph * build_graph_2(const test_model& model) {
-    static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();
-    static std::vector<uint8_t> buf(buf_size);
-
-    struct ggml_init_params params0 = {
-        /*.mem_size   =*/ buf_size,
-        /*.mem_buffer =*/ buf.data(),
-        /*.no_alloc   =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()
-    };
-
-    // create a temporally context to build the graph
-    struct ggml_context * ctx0 = ggml_init(params0);
-
-    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
-
-    int s0 = 1;
-    int s1 = 1;
-    int p0 = 1;
-    int p1 = 1;
-    int d0 = 1;
-    int d1 = 1;
-
-
-    // recalculate for avoid fragmentation
-    // struct ggml_tensor* conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1);
-    // ggml_set_name(conv2d_res, "conv2d_res");
-    // ggml_build_forward_expand(gf, conv2d_res);
-    // int64_t *ne = conv2d_res->ne;
-    // printf("conv2d: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]);
-
-
-    struct ggml_tensor* wino_res = ggml_conv_2d_implicitgemm(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1);
-    // struct ggml_tensor* wino_res = ggml_conv_2d_direct(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1);
-    ggml_set_name(wino_res, "wino_res");
-    ggml_build_forward_expand(gf, wino_res);
-    // ne = wino_res->ne;
-    // printf("wino: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]);
-    ggml_free(ctx0);
-    return gf;
-}
-
-
-
 std::vector<float> compute_graph(const test_model & model, ggml_gallocr_t allocr,
                                  build_graph_t build_graph, int iters, double *t) {
@@ -352,10 +309,10 @@ int main(void)
             // std::make_tuple(640,640,52,76,3,3),
             // std::make_tuple(640,640,104,152,3,3),
             // std::make_tuple(960,320,104,152,3,3),
-            // std::make_tuple(1280,1280,26,38,3,3),
+            std::make_tuple(1280,1280,26,38,3,3),
             // std::make_tuple(1280,1280,26,38,1,1),
             // std::make_tuple(256,128,768,1024,3,3),
-            std::make_tuple(128,3,768,1024,3,3),
+            // std::make_tuple(128,3,768,1024,3,3),
             // std::make_tuple(256,128,768,1024,1,1),
             // std::make_tuple(512,256,384,512,1,1),
             // std::make_tuple(1280,640,52,76,3,3),
@@ -389,7 +346,7 @@ int main(void)
 
     struct ggml_cgraph * gf_res_0 = NULL;
 
-    int iterations = 20;
+    int iterations = 0;
 
     double run_time0;
     std::vector<float> im2col_data = compute_graph(model, allocr, build_graph_0, iterations, &run_time0);
@@ -418,51 +375,31 @@ int main(void)
 
     ggml_gallocr_free(allocr);
 
-    allocr = NULL;
-
-    allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend));
-
-    //create the worst case graph for memory usage estimation
-    gf = build_graph_2(model);
-
-    // compute the required memory
-    ggml_gallocr_reserve(allocr, gf);
-    size_t mem_size2 = ggml_gallocr_get_buffer_size(allocr, 0);
-    // fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f);
-
-
-    struct ggml_cgraph * gf_res_2 = NULL;
-
-    double run_time2;
-    std::vector<float> wino_data = compute_graph(model, allocr, build_graph_2, iterations, &run_time2);
-
-
     if(k==0) {
        k = 1;
-       fprintf(stderr, "| (IC, OC, IW, IH) | im2col+GEMM TIME | im2col+GEMM VRAM | direct TIME | direct VRAM | implicit GEMM TIME | implicit GEMM VRAM \n");
-       fprintf(stderr, "| --- | --- | --- | --- | --- | --- | --- \n");
+       fprintf(stderr, "| (IC, OC, IW, IH) | im2col+GEMM TIME | im2col+GEMM VRAM | implicit GEMM TIME | implicit GEMM VRAM \n");
+       fprintf(stderr, "| --- | --- | --- | --- | --- \n");
     }
 
-    fprintf(stderr, " | (%d, %d, %d, %d, %d, %d) | %.2f ms | %.2f MB | %.2f ms | %.2f MB | %.2f ms | %.2f MB\n",
+    fprintf(stderr, " | (%d, %d, %d, %d) | %.2f ms | %.2f MB | %.2f ms | %.2f MB\n",
             std::get<0>(c), std::get<1>(c), std::get<2>(c), std::get<3>(c), std::get<4>(c), std::get<5>(c),
             run_time0, mem_size0/1024.0f/1024.0f,
-            run_time1, mem_size1/1024.0f/1024.0f,
-            run_time2, mem_size2/1024.0f/1024.0f);
+            run_time1, mem_size1/1024.0f/1024.0f
+            );
 
     // for(int i = 0; i < ggml_nelements(wino_res); i++) {
     // for(int i = 0; i < 26*38; i++) {
-    // for(int i = 0; i < conv2d_data.size(); i++) {
-    //     // float diff = fabs(conv2d_data[i] - wino_data[i]);
-    //     float diff = fabs(im2col_data[i] - wino_data[i]);
-    //     float diff1 = fabs(im2col_data[i] - conv2d_data[i]);
-    //     // if(diff > 0.5) {
-    //        printf("(%7.3f, %7.3f, %7.3f, %.2f, %.2f, %d) \n",
-    //            im2col_data[i], conv2d_data[i],
-    //            wino_data[i], diff, diff1, i);
-    //     //    break;
-    //     // }
-    //  }
+    for(int i = 0; i < conv2d_data.size(); i++) {
+        // float diff = fabs(conv2d_data[i] - wino_data[i]);
+        float diff = fabs(im2col_data[i] - conv2d_data[i]);
+        // if(diff > 0.5) {
+           printf("(%7.3f, %7.3f, %.2f, %d) \n",
+               im2col_data[i], conv2d_data[i],
+               diff, i);
+        //    break;
+        // }
+    }
 
     ggml_free(model.ctx);
     ggml_backend_buffer_free(model.buffer);
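-- 

Note on the NCHW2NHWC kernel added above: each matrix is transposed through a
32x32 shared-memory tile, and the tile column is XOR-swizzled with the row
index (col = threadIdx.x ^ row). Without the swizzle, the read-back phase
(tile[threadIdx.x][...]) would have all 32 threads of a warp hitting the same
shared-memory bank; the XOR spreads them across banks, giving the same effect
as the classic [32][33] padding trick without the extra column. The standalone
sketch below demonstrates only this swizzle pattern on a single float matrix;
the patch's kernel additionally batches nmat matrices over blockIdx.z and
casts to half on the way out. Names (TILE, ROWS, transpose_swizzled) and the
host scaffolding are illustrative, not part of the patch:

    // transpose_swizzle.cu -- build with: nvcc -o transpose_swizzle transpose_swizzle.cu
    #include <cstdio>
    #include <cuda_runtime.h>

    constexpr int TILE = 32; // tile is TILE x TILE, like CUDA_NCHW_2_NHWC_TILE_DIM
    constexpr int ROWS = 8;  // block is TILE x ROWS threads, like CUDA_NCHW_2_NHWC_BLOCK_ROWS

    __global__ void transpose_swizzled(const float *src, float *dst, int nrows, int ncols) {
        __shared__ float tile[TILE][TILE];

        const int x = blockIdx.x * TILE + threadIdx.x; // column in src
        const int y = blockIdx.y * TILE + threadIdx.y; // row in src

        // load phase: logical element (row, threadIdx.x) is stored at swizzled column threadIdx.x ^ row
        for (int j = 0; j < TILE; j += ROWS) {
            if (x < ncols && y + j < nrows) {
                const int row = threadIdx.y + j;
                tile[row][threadIdx.x ^ row] = src[(size_t)(y + j) * ncols + x];
            }
        }
        __syncthreads();

        // store phase: read the transposed element back with the matching swizzle,
        // so consecutive threads of a warp touch distinct banks
        const int tx = blockIdx.y * TILE + threadIdx.x; // column in dst (= row in src)
        const int ty = blockIdx.x * TILE + threadIdx.y; // row in dst (= column in src)
        for (int j = 0; j < TILE; j += ROWS) {
            if (tx < nrows && ty + j < ncols) {
                dst[(size_t)(ty + j) * nrows + tx] = tile[threadIdx.x][(threadIdx.y + j) ^ threadIdx.x];
            }
        }
    }

    int main() {
        const int nrows = 64, ncols = 48;
        const size_t n = (size_t)nrows * ncols;

        float *src, *dst; // unified memory keeps the example short
        cudaMallocManaged(&src, n * sizeof(float));
        cudaMallocManaged(&dst, n * sizeof(float));
        for (size_t i = 0; i < n; ++i) src[i] = (float)i;

        dim3 grid((ncols + TILE - 1) / TILE, (nrows + TILE - 1) / TILE);
        dim3 block(TILE, ROWS);
        transpose_swizzled<<<grid, block>>>(src, dst, nrows, ncols);
        cudaDeviceSynchronize();

        bool ok = true; // spot-check: dst[c][r] must equal src[r][c]
        for (int r = 0; r < nrows && ok; ++r)
            for (int c = 0; c < ncols && ok; ++c)
                ok = dst[(size_t)c * nrows + r] == src[(size_t)r * ncols + c];
        printf("transpose %s\n", ok ? "OK" : "MISMATCH");

        cudaFree(src); cudaFree(dst);
        return ok ? 0 : 1;
    }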
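Note on the public API after this patch: GGML_OP_CONV_2D_IMPLICIT and
ggml_conv_2d_implicitgemm() are removed, and on CUDA the existing
GGML_OP_CONV_2D node is now executed by ggml_cuda_op_conv2d_implicit() instead
of ggml_cuda_op_conv2d(). A minimal caller-side sketch, assuming a graph built
with ggml_conv_2d_direct() (the constructor of GGML_OP_CONV_2D nodes, visible
in the ggml.c hunk above); build_conv is a made-up helper name:

    #include "ggml.h"

    // kernel: [KW, KH, IC, OC], input: [W, H, IC, N] -- same layout the
    // removed ggml_conv_2d_implicitgemm() documented in ggml.h
    struct ggml_tensor * build_conv(struct ggml_context * ctx,
                                    struct ggml_tensor  * kernel,
                                    struct ggml_tensor  * input) {
        // before: ggml_conv_2d_implicitgemm(ctx, kernel, input, 1, 1, 1, 1, 1, 1);
        // after: a plain CONV_2D op reaches the implicit-GEMM CUDA kernel directly
        return ggml_conv_2d_direct(ctx, kernel, input,
                                   /*s0*/ 1, /*s1*/ 1,   // stride
                                   /*p0*/ 1, /*p1*/ 1,   // padding
                                   /*d0*/ 1, /*d1*/ 1);  // dilation
    }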