From 0a64ea8ff84add8bafc7464fb0018d223ae80e79 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sun, 2 Nov 2025 10:34:03 -0500 Subject: [PATCH] WIP: build ok --- ggml/src/ggml-cuda/conv3d-implicit.cu | 355 ++++++++++++++++--------- ggml/src/ggml-cuda/conv3d-implicit.cuh | 21 +- 2 files changed, 236 insertions(+), 140 deletions(-) diff --git a/ggml/src/ggml-cuda/conv3d-implicit.cu b/ggml/src/ggml-cuda/conv3d-implicit.cu index c6aa7e0749..00aaa568af 100644 --- a/ggml/src/ggml-cuda/conv3d-implicit.cu +++ b/ggml/src/ggml-cuda/conv3d-implicit.cu @@ -62,6 +62,29 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co } } +template +__device__ int4 inputIndices(const uint kidx, param_t param) { + + const uint cur0 = fastdiv(kidx, + layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset + const uint cur0_res = fastmodulo(kidx, + layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset + const uint cur1 = fastdiv(cur0_res, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset + const uint cur1_res = fastmodulo(cur0_res, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset + const uint cur2 = fastdiv(cur1_res, + layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + const uint cur3 = fastmodulo(cur1_res, + layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + const uint curC = layout == 0 ? cur3 : cur0; + const uint curT = layout == 0 ? cur0 : cur1; + const uint curR = layout == 0 ? cur1 : cur2; + const uint curS = layout == 0 ? cur2 : cur3; + return make_int4(curC, curT, curR, curS); + +} + template 0) ? (weightKOffset + ksplit - 1) / ksplit : weightKOffset; const uint start_k = (ksplit > 0)? z * ks: 0; @@ -158,31 +182,44 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input, const uint input_sts_addr = innerRowA + innerColA * (BM+PAD) * 4; + const uint inKOffset = start_k + innerColA * 4; #pragma unroll for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) { - int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQZ : z; - const unsigned int npqz_res = (bx * BM + innerRowA + offset) % PQZ; - const int posd_ori = fastdiv((ksplit > 0) ? npqz_res: bx * BM + innerRowA + offset, param.OWOH_fastdiv) * param.stride2 - param.padding2; - const int ohow_res = fastmodulo((ksplit > 0) ? npqz_res: bx * BM + innerRowA + offset, param.OWOH_fastdiv); + const unsigned int gemm_i = bx * BM + innerRowA + offset; + // int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQZ : z; + int n = (ksplit > 0) ? fastdiv(gemm_i, param.PQZ_fastdiv) : z; + const unsigned int npqz_res = fastmodulo(gemm_i, param.PQZ_fastdiv); + const int posd_ori = fastdiv((ksplit > 0) ? npqz_res: gemm_i, param.OHOW_fastdiv) * param.stride2 - param.padding2; + const int ohow_res = fastmodulo((ksplit > 0) ? npqz_res: gemm_i, param.OHOW_fastdiv); const int posh_ori = fastdiv(ohow_res, param.OW_fastdiv) * param.stride1 - param.padding1; const int posw_ori = fastmodulo(ohow_res, param.OW_fastdiv) * param.stride0 - param.padding0; - int inOffset = n * param.c * param.h * param.w * param.d; + int inOffset = n * inNOffset; if(vec_load){ - const uint cur0 = fastdiv(start_k + innerColA * 4, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset - const uint cur1 = fastdiv(fastmodulo(start_k + innerColA * 4, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), - layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - const uint cur2 = fastmodulo(fastmodulo(start_k + innerColA * 4, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), - layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - const uint curC = layout == 0 ? cur2 : cur0; - const uint curR = layout == 0 ? cur0 : cur1; - const uint curS = layout == 0 ? cur1 : cur2; - const int curD = posd_ori + curT * param.dilation2; // input w - const int curH = posh_ori + curR * param.dilation1; // input h - const int curW = posw_ori + curS * param.dilation0; // input w - if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && start_k + innerColA * 4 < end_k){ + // const uint cur0 = fastdiv(inKOffset, + // layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset + // const uint cur0_res = fastmodulo(inKOffset, + // layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset + // const uint cur1 = fastdiv(cur0_res, + // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset + // const uint cur1_res = fastmodulo(cur0_res, + // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset + // const uint cur2 = fastdiv(cur1_res, + // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + // const uint cur3 = fastmodulo(cur1_res, + // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + // const uint curC = layout == 0 ? cur3 : cur0; + // const uint curT = layout == 0 ? cur0 : cur1; + // const uint curR = layout == 0 ? cur1 : cur2; + // const uint curS = layout == 0 ? cur2 : cur3; + const int4 curIdx = inputIndices(inKOffset, param); + // const int curD = posd_ori + curT * param.dilation2; // input w + // const int curH = posh_ori + curR * param.dilation1; // input h + // const int curW = posw_ori + curS * param.dilation0; // input w + const int curD = posd_ori + curIdx.y * param.dilation2; // input w + const int curH = posh_ori + curIdx.z * param.dilation1; // input h + const int curW = posw_ori + curIdx.w * param.dilation0; // input w + const int curC = curIdx.x; + if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && inKOffset < end_k){ int inOffsetTmp = layout == 0 ? curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC: curC * inDepthOffset + curD * inChannelOffset + curH * param.w + curW; @@ -199,23 +236,47 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input, } else { #pragma unroll for (int i = 0; i < 4; ++i){ - const uint cur0 = fastdiv(start_k + innerColA * 4 + i, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset - const uint cur1 = fastdiv(fastmodulo(start_k + innerColA * 4 + i, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), - layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - const uint cur2 = fastmodulo(fastmodulo(start_k + innerColA * 4 + i, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), - layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - const uint curC = layout == 0 ? cur2 : cur0; - const uint curR = layout == 0 ? cur0 : cur1; - const uint curS = layout == 0 ? cur1 : cur2; - const int curH = posh_ori + curR * param.d_h; // input h - const int curW = posw_ori + curS * param.d_w; // input w - if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 + i < end_k){ + // const uint cur0 = fastdiv(inKOffset + i, + // layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset + // const uint cur0_res = fastmodulo(inKOffset + i, + // layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset + // const uint cur1 = fastdiv(cur0_res, + // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset + // const uint cur1_res = fastmodulo(cur0_res, + // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset + // const uint cur2 = fastdiv(cur1_res, + // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + // const uint cur3 = fastmodulo(cur1_res, + // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + // const uint curC = layout == 0 ? cur3 : cur0; + // const uint curT = layout == 0 ? cur0 : cur1; + // const uint curR = layout == 0 ? cur1 : cur2; + // const uint curS = layout == 0 ? cur2 : cur3; + const int4 curIdx = inputIndices(inKOffset + i, param); + // const int curD = posd_ori + curT * param.dilation2; // input w + // const int curH = posh_ori + curR * param.dilation1; // input h + // const int curW = posw_ori + curS * param.dilation0; // input w + const int curD = posd_ori + curIdx.y * param.dilation2; // input w + const int curH = posh_ori + curIdx.z * param.dilation1; // input h + const int curW = posw_ori + curIdx.w * param.dilation0; // input w + const int curC = curIdx.x; + // const uint cur0 = fastdiv(start_k + innerColA * 4 + i, + // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset + // const uint cur1 = fastdiv(fastmodulo(start_k + innerColA * 4 + i, + // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + // const uint cur2 = fastmodulo(fastmodulo(start_k + innerColA * 4 + i, + // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + // const uint curC = layout == 0 ? cur2 : cur0; + // const uint curR = layout == 0 ? cur0 : cur1; + // const uint curS = layout == 0 ? cur1 : cur2; + // const int curH = posh_ori + curR * param.d_h; // input h + // const int curW = posw_ori + curS * param.d_w; // input w + if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && inKOffset + i < end_k){ int inOffsetTmp = layout == 0 ? - curH * inChannelOffset + curW * param.c + curC: - curC * inChannelOffset + curH * param.w + curW; + curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC: + curC * inDepthOffset + curD * inChannelOffset + curH * param.w + curW; smeminput[input_sts_addr + offset + i*(BM+PAD)] = input[inOffset + inOffsetTmp]; } else { smeminput[input_sts_addr + offset + i*(BM+PAD)] = 0.f; @@ -242,6 +303,7 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input, weight_frag[0][wSubColIdx * TN + i] = smemweight[weight_lds_addr + wSubColIdx * WSUBN + threadColInWarp * TN + i]; + // main block k loop for (int crs = start_k; crs < end_k; crs += BK) { int load_flag = write_flag ^ 1; @@ -317,33 +379,50 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input, } } } + const uint inKkOffset = innerColA * 4 + crs + BK; #pragma unroll for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) { - int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQ : z; - const unsigned int npq_res = (bx * BM + innerRowA + offset) % PQ; - const int posh_ori = fastdiv((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.u - param.p; - const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q; - int inOffset = n * param.c * param.h * param.w ; + // int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQ : z; + // const unsigned int npq_res = (bx * BM + innerRowA + offset) % PQ; + // const int posh_ori = fastdiv((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.u - param.p; + // const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q; + // int inOffset = n * param.c * param.h * param.w ; + const unsigned int gemm_i = bx * BM + innerRowA + offset; + int n = (ksplit > 0) ? fastdiv(gemm_i, param.PQZ_fastdiv) : z; + const unsigned int npqz_res = fastmodulo(gemm_i, param.PQZ_fastdiv); + const int posd_ori = fastdiv((ksplit > 0) ? npqz_res: gemm_i, param.OHOW_fastdiv) * param.stride2 - param.padding2; + const int ohow_res = fastmodulo((ksplit > 0) ? npqz_res: gemm_i, param.OHOW_fastdiv); + const int posh_ori = fastdiv(ohow_res, param.OW_fastdiv) * param.stride1 - param.padding1; + const int posw_ori = fastmodulo(ohow_res, param.OW_fastdiv) * param.stride0 - param.padding0; + int inOffset = n * inNOffset; if(vec_load){ - const uint cur0 = fastdiv(innerColA * 4 + crs + BK, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset - const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), - layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), - layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - const uint curC = layout == 0 ? cur2 : cur0; - const uint curR = layout == 0 ? cur0 : cur1; - const uint curS = layout == 0 ? cur1 : cur2; + const int4 curIdx = inputIndices(inKkOffset, param); + const int curD = posd_ori + curIdx.y * param.dilation2; // input w + const int curH = posh_ori + curIdx.z * param.dilation1; // input h + const int curW = posw_ori + curIdx.w * param.dilation0; // input w + const int curC = curIdx.x; + // const uint cur0 = fastdiv(innerColA * 4 + crs + BK, + // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset + // const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK, + // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + // const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK, + // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + // const uint curC = layout == 0 ? cur2 : cur0; + // const uint curR = layout == 0 ? cur0 : cur1; + // const uint curS = layout == 0 ? cur1 : cur2; - const int curH = posh_ori + curR * param.d_h; // input h - const int curW = posw_ori + curS * param.d_w; // input w - if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK < end_k){ - // int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; + // const int curH = posh_ori + curR * param.d_h; // input h + // const int curW = posw_ori + curS * param.d_w; // input w + if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && inKkOffset < end_k){ int inOffsetTmp = layout == 0 ? - curH * inChannelOffset + curW * param.c + curC: - curC * inChannelOffset + curH * param.w + curW; + curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC: + curC * inDepthOffset + curD * inChannelOffset + curH * param.w + curW; + // if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && inKkOffset < end_k){ + // int inOffsetTmp = layout == 0 ? + // curH * inChannelOffset + curW * param.c + curC: + // curC * inChannelOffset + curH * param.w + curW; float4 tmp = reinterpret_cast(&input[inOffset + inOffsetTmp])[0]; smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + 0] = tmp.x; smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + BM+PAD] = tmp.y; @@ -357,24 +436,33 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input, } else { #pragma unroll for (int i = 0; i < 4; ++i){ - const uint cur0 = fastdiv(innerColA * 4 + crs + BK + i, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset - const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK + i, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), - layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK + i, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), - layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - const uint curC = layout == 0 ? cur2 : cur0; - const uint curR = layout == 0 ? cur0 : cur1; - const uint curS = layout == 0 ? cur1 : cur2; + // const uint cur0 = fastdiv(innerColA * 4 + crs + BK + i, + // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset + // const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK + i, + // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + // const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK + i, + // layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + // layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + // const uint curC = layout == 0 ? cur2 : cur0; + // const uint curR = layout == 0 ? cur0 : cur1; + // const uint curS = layout == 0 ? cur1 : cur2; - const int curH = posh_ori + curR * param.d_h; // input h - const int curW = posw_ori + curS * param.d_w; // input w - if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK + i < end_k){ + // const int curH = posh_ori + curR * param.d_h; // input h + // const int curW = posw_ori + curS * param.d_w; // input w + const int4 curIdx = inputIndices(inKkOffset + i, param); + const int curD = posd_ori + curIdx.y * param.dilation2; // input w + const int curH = posh_ori + curIdx.z * param.dilation1; // input h + const int curW = posw_ori + curIdx.w * param.dilation0; // input w + const int curC = curIdx.x; + // if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK + i < end_k){ + // int inOffsetTmp = layout == 0 ? + // curH * inChannelOffset + curW * param.c + curC: + // curC * inChannelOffset + curH * param.w + curW; + if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && inKkOffset + i < end_k){ int inOffsetTmp = layout == 0 ? - curH * inChannelOffset + curW * param.c + curC: - curC * inChannelOffset + curH * param.w + curW; + curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC: + curC * inDepthOffset + curD * inChannelOffset + curH * param.w + curW; smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + i*(BM+PAD)] = input[inOffset + inOffsetTmp]; } else { smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + i*(BM+PAD)] = 0.f; @@ -448,15 +536,16 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input, __syncthreads(); #pragma unroll for (int subk = 0; subk < TM * TN; ++subk){ + // output: [N*OC, OD, OH, OW] const uint row = m_idx + j * WSUBN + (lane_id + subk * WARPSIZE) / WSUBM; const uint gemm_i = n_idx + i * WSUBM + (lane_id + subk * WARPSIZE) % WSUBM; - const int n = (ksplit > 0) ? gemm_i / PQ : z; - const int col = (ksplit > 0) ? gemm_i % PQ : gemm_i; - if (n < param.n && row < param.k && col < param.Oh * param.Ow){ + const int n = (ksplit > 0) ? fastdiv(gemm_i, param.PQZ_fastdiv) : z; + const int col = (ksplit > 0) ? fastmodulo(gemm_i, param.PQZ_fastdiv) : gemm_i; + if (n < param.n && row < param.k && col < PQZ){ const uint outOffset = ksplit > 0 ? - z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow + - row * param.Oh * param.Ow + col : - z * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col; + // z * param.n * param.k * PQZ + n * param.k * PQZ + row * PQZ + col : + ((z * param.n + n) * param.k + row) * PQZ + col : + (z * param.k + row) * PQZ + col; output[outOffset] = smemoutput[output_lds_addr + subk * WARPSIZE]; } } @@ -464,7 +553,7 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input, } } - +#if 0 template __device__ __forceinline__ void ldmatrix_a( @@ -885,6 +974,7 @@ constexpr unsigned int MMA_N = 8; #endif } +#endif #define NUM_VARIANTS 4 @@ -925,70 +1015,70 @@ static void conv3d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, dim3 thblock(NUM_THREADS, thready, threadz); dim3 grid(blockx, blocky, blockz); - conv2d_implicit_kernel<<>>(X_D, K_D, Y_D, P); } static void conv3d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const float * X_D, const half * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) { - if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0 && (P.r > 1 || P.s > 1 || P.t > 1)) { + // if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0 && (P.r > 1 || P.s > 1 || P.t > 1)) { - int id = ggml_cuda_get_device(); + // int id = ggml_cuda_get_device(); - int64_t ne = P.c * P.h * P.w * P.n; - int64_t ne00 = P.c; - int64_t ne01 = P.h * P.w; - ggml_cuda_pool_alloc input_f16(ctx.pool(id), ne); + // int64_t ne = P.c * P.h * P.w * P.n; + // int64_t ne00 = P.c; + // int64_t ne01 = P.h * P.w; + // ggml_cuda_pool_alloc input_f16(ctx.pool(id), ne); - dim3 dimGrid( (ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, - (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, - (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ; - dim3 dimBlock(CUDA_NCHW_2_NHWC_TILE_DIM,CUDA_NCHW_2_NHWC_BLOCK_ROWS, 1); - NCHW2NHWC<<>>(X_D, input_f16.get(), ne, ne00, ne01); + // dim3 dimGrid( (ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, + // (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, + // (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ; + // dim3 dimBlock(CUDA_NCHW_2_NHWC_TILE_DIM,CUDA_NCHW_2_NHWC_BLOCK_ROWS, 1); + // NCHW2NHWC<<>>(X_D, input_f16.get(), ne, ne00, ne01); - ne = P.c * P.r * P.s * P.k; - ne01 = P.r * P.s; - ggml_cuda_pool_alloc kernel_f16(ctx.pool(id), ne); - dim3 dimGrid1((ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, - (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, - (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ; - NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + // ne = P.c * P.r * P.s * P.k; + // ne01 = P.r * P.s; + // ggml_cuda_pool_alloc kernel_f16(ctx.pool(id), ne); + // dim3 dimGrid1((ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, + // (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, + // (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ; + // NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); - const half *X_H = input_f16.get(); - const half *K_H = kernel_f16.get(); - ggml_cuda_pool_alloc Y_H(ctx.pool(id), P.k * P.Oh * P.Ow * P.n); + // const half *X_H = input_f16.get(); + // const half *K_H = kernel_f16.get(); + // ggml_cuda_pool_alloc Y_H(ctx.pool(id), P.k * P.Oh * P.Ow * P.n); - constexpr unsigned int BM_dim = 256; - constexpr unsigned int BN_dim = 256; - constexpr unsigned int BK_dim = 32; + // constexpr unsigned int BM_dim = 256; + // constexpr unsigned int BN_dim = 256; + // constexpr unsigned int BK_dim = 32; - constexpr unsigned int WARPS_PER_BLOCK_M = 2; - constexpr unsigned int WARPS_PER_BLOCK_N = 4; - constexpr unsigned int WARPS_PER_BLOCK_K = 4; + // constexpr unsigned int WARPS_PER_BLOCK_M = 2; + // constexpr unsigned int WARPS_PER_BLOCK_N = 4; + // constexpr unsigned int WARPS_PER_BLOCK_K = 4; - constexpr unsigned int WM_dim = BM_dim / WARPS_PER_BLOCK_M; - constexpr unsigned int WN_dim = BN_dim / WARPS_PER_BLOCK_N; - constexpr unsigned int WK_dim = BK_dim / WARPS_PER_BLOCK_K; - const unsigned int BlocksM = (P.n * P.Oh * P.Ow + BM_dim - 1) / BM_dim; - const unsigned int BlocksN = (P.k + BN_dim - 1) / BN_dim; - constexpr unsigned int ThreadsM = WARPS_PER_BLOCK_M; - constexpr unsigned int ThreadsN = WARPSIZE * WARPS_PER_BLOCK_N; - constexpr unsigned int NumThreads = ThreadsM * ThreadsN; - const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half); + // constexpr unsigned int WM_dim = BM_dim / WARPS_PER_BLOCK_M; + // constexpr unsigned int WN_dim = BN_dim / WARPS_PER_BLOCK_N; + // constexpr unsigned int WK_dim = BK_dim / WARPS_PER_BLOCK_K; + // const unsigned int BlocksM = (P.n * P.Oh * P.Ow + BM_dim - 1) / BM_dim; + // const unsigned int BlocksN = (P.k + BN_dim - 1) / BN_dim; + // constexpr unsigned int ThreadsM = WARPS_PER_BLOCK_M; + // constexpr unsigned int ThreadsN = WARPSIZE * WARPS_PER_BLOCK_N; + // constexpr unsigned int NumThreads = ThreadsM * ThreadsN; + // const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half); - cudaFuncSetAttribute(conv3d_implicit_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75 - dim3 gridDim(BlocksN, BlocksM); - dim3 blockDim(ThreadsN, ThreadsM); + // cudaFuncSetAttribute(conv3d_implicit_kernel, + // cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75 + // dim3 gridDim(BlocksN, BlocksM); + // dim3 blockDim(ThreadsN, ThreadsM); - conv3d_implicit_kernel - <<>>(X_H, K_H, Y_H.get(), P); - const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); - to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow * P.n, st); - } else{ + // conv3d_implicit_kernel + // <<>>(X_H, K_H, Y_H.get(), P); + // const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); + // to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow * P.n, st); + // } else{ conv3d_implicit_cuda(X_D, K_D, Y_D, P, st); - } + // } } @@ -1056,7 +1146,10 @@ void ggml_cuda_op_conv3d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * init_fastdiv_values(IC), init_fastdiv_values(KW*KH), init_fastdiv_values(KW), - init_fastdiv_values(OW*OH)}; + init_fastdiv_values(OW*OH), + init_fastdiv_values(OW*OH*OD), + init_fastdiv_values(KW*KH*IC), + init_fastdiv_values(KW*KH*KD)}; if (kernel->type == GGML_TYPE_F16) { conv3d_implicit_cuda_f16(ctx, X_D, (half *) K_D, Y_D, cc, params, st); diff --git a/ggml/src/ggml-cuda/conv3d-implicit.cuh b/ggml/src/ggml-cuda/conv3d-implicit.cuh index d550ec07c8..4e14c15cd2 100644 --- a/ggml/src/ggml-cuda/conv3d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv3d-implicit.cuh @@ -29,9 +29,13 @@ typedef struct{ uint3 RS_fastdiv; uint3 S_fastdiv; uint3 OHOW_fastdiv; + uint3 PQZ_fastdiv; + uint3 RSC_fastdiv; + uint3 TRS_fastdiv; } param_t; + // same as above, but writes are swizzled to avoid bank conflicts when shared memory is read later in the kernel template @@ -131,14 +135,14 @@ __device__ __forceinline__ void tileMemcpySwizzleA( unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row; unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv); unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv); - int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p; - int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q; + int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.stride1 - param.padding1; + int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.stride0 - param.padding0; unsigned int inOffset = n * param.c * param.h * param.w; const unsigned int curR = fastdiv(thread_col*8, param.SC_fastdiv); // channel offset const unsigned int curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset const unsigned int curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - int curH = posh_ori + curR * param.d_h; // input h - int curW = posw_ori + curS * param.d_w; // input w + int curH = posh_ori + curR * param.dilation1; // input h + int curW = posw_ori + curS * param.dilation0; // input w // apply swizzle to the dst index unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); @@ -197,14 +201,14 @@ __device__ __forceinline__ void tileMemcpyLoadA( unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row; unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv); unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv); - int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p; - int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q; + int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.stride1 - param.padding1; + int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.stride0 - param.padding0; unsigned int inOffset = n * param.c * param.h * param.w; const unsigned int curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // channel offset const unsigned int curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - int curH = posh_ori + curR * param.d_h; // input h - int curW = posw_ori + curS * param.d_w; // input w + int curH = posh_ori + curR * param.dilation1; // input h + int curW = posw_ori + curS * param.dilation0; // input w if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curR < param.r && curS < param.s && curC < param.c){ const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; @@ -348,6 +352,5 @@ __device__ __forceinline__ uint32_t cvta_to_shared_u32(const void *pointer) { return address; } - #define CUDA_CONV3D_IMPLICT_BLOCK_SIZE 256 void ggml_cuda_op_conv3d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst);