diff --git a/ggml/src/ggml-cuda/conv3d-implicit.cu b/ggml/src/ggml-cuda/conv3d-implicit.cu index 640366b80a..c6aa7e0749 100644 --- a/ggml/src/ggml-cuda/conv3d-implicit.cu +++ b/ggml/src/ggml-cuda/conv3d-implicit.cu @@ -80,7 +80,7 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input, const uint bx = blockIdx.x; const uint by = blockIdx.y; - const uint PQ = param.Oh * param.Ow; + const uint PQZ = param.Oh * param.Ow * param.Oz; // Warp tile const uint lane_id = tx % WARPSIZE; @@ -100,7 +100,8 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input, int z = blockIdx.z; int inChannelOffset = layout == 0 ? param.c * param.w : param.h * param.w; - int weightKOffset = param.c * param.r * param.s; + int inDepthOffset = layout == 0 ? param.h * param.c * param.w : param.d * param.h * param.w; + int weightKOffset = param.c * param.r * param.s * param.t; const uint ks = (ksplit > 0) ? (weightKOffset + ksplit - 1) / ksplit : weightKOffset; const uint start_k = (ksplit > 0)? z * ks: 0; @@ -159,11 +160,13 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input, const uint input_sts_addr = innerRowA + innerColA * (BM+PAD) * 4; #pragma unroll for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) { - int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQ : z; - const unsigned int npq_res = (bx * BM + innerRowA + offset) % PQ; - const int posh_ori = fastdiv((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.u - param.p; - const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q; - int inOffset = n * param.c * param.h * param.w ; + int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQZ : z; + const unsigned int npqz_res = (bx * BM + innerRowA + offset) % PQZ; + const int posd_ori = fastdiv((ksplit > 0) ? npqz_res: bx * BM + innerRowA + offset, param.OWOH_fastdiv) * param.stride2 - param.padding2; + const int ohow_res = fastmodulo((ksplit > 0) ? npqz_res: bx * BM + innerRowA + offset, param.OWOH_fastdiv); + const int posh_ori = fastdiv(ohow_res, param.OW_fastdiv) * param.stride1 - param.padding1; + const int posw_ori = fastmodulo(ohow_res, param.OW_fastdiv) * param.stride0 - param.padding0; + int inOffset = n * param.c * param.h * param.w * param.d; if(vec_load){ const uint cur0 = fastdiv(start_k + innerColA * 4, layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset @@ -176,12 +179,13 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input, const uint curC = layout == 0 ? cur2 : cur0; const uint curR = layout == 0 ? cur0 : cur1; const uint curS = layout == 0 ? cur1 : cur2; - const int curH = posh_ori + curR * param.d_h; // input h - const int curW = posw_ori + curS * param.d_w; // input w - if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 < end_k){ + const int curD = posd_ori + curT * param.dilation2; // input w + const int curH = posh_ori + curR * param.dilation1; // input h + const int curW = posw_ori + curS * param.dilation0; // input w + if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && start_k + innerColA * 4 < end_k){ int inOffsetTmp = layout == 0 ? - curH * inChannelOffset + curW * param.c + curC: - curC * inChannelOffset + curH * param.w + curW; + curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC: + curC * inDepthOffset + curD * inChannelOffset + curH * param.w + curW; float4 tmp = reinterpret_cast(&input[inOffset + inOffsetTmp])[0]; smeminput[input_sts_addr + offset + 0] = tmp.x; smeminput[input_sts_addr + offset + BM+PAD] = tmp.y; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 61a8f1df87..7f16a3b826 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -13,6 +13,7 @@ #include "ggml-cuda/concat.cuh" #include "ggml-cuda/conv-transpose-1d.cuh" #include "ggml-cuda/conv2d.cuh" +#include "ggml-cuda/conv3d-implicit.cuh" #include "ggml-cuda/conv2d-dw.cuh" #include "ggml-cuda/conv2d-transpose.cuh" #include "ggml-cuda/convert.cuh" @@ -2629,6 +2630,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_CONV_2D: ggml_cuda_op_conv2d(ctx, dst); break; + case GGML_OP_CONV_3D: + ggml_cuda_op_conv3d_implicit(ctx, dst); + break; case GGML_OP_CONV_2D_DW: ggml_cuda_op_conv2d_dw(ctx, dst); break; @@ -4041,6 +4045,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_IM2COL: case GGML_OP_IM2COL_3D: case GGML_OP_CONV_2D: + case GGML_OP_CONV_3D: case GGML_OP_CONV_2D_DW: case GGML_OP_CONV_TRANSPOSE_2D: case GGML_OP_POOL_2D: