WIP: updating indices for input and kernel; enable OP_CONV_3D for cuda backend

This commit is contained in:
bssrdf 2025-11-01 22:01:00 -04:00
parent ab15f6cd5f
commit 52455b8a6d
2 changed files with 21 additions and 12 deletions

View File

@ -80,7 +80,7 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
const uint bx = blockIdx.x;
const uint by = blockIdx.y;
const uint PQ = param.Oh * param.Ow;
const uint PQZ = param.Oh * param.Ow * param.Oz;
// Warp tile
const uint lane_id = tx % WARPSIZE;
@ -100,7 +100,8 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
int z = blockIdx.z;
int inChannelOffset = layout == 0 ? param.c * param.w : param.h * param.w;
int weightKOffset = param.c * param.r * param.s;
int inDepthOffset = layout == 0 ? param.h * param.c * param.w : param.d * param.h * param.w;
int weightKOffset = param.c * param.r * param.s * param.t;
const uint ks = (ksplit > 0) ? (weightKOffset + ksplit - 1) / ksplit : weightKOffset;
const uint start_k = (ksplit > 0)? z * ks: 0;
@ -159,11 +160,13 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
const uint input_sts_addr = innerRowA + innerColA * (BM+PAD) * 4;
#pragma unroll
for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) {
int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQ : z;
const unsigned int npq_res = (bx * BM + innerRowA + offset) % PQ;
const int posh_ori = fastdiv((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.u - param.p;
const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q;
int inOffset = n * param.c * param.h * param.w ;
int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQZ : z;
const unsigned int npqz_res = (bx * BM + innerRowA + offset) % PQZ;
const int posd_ori = fastdiv((ksplit > 0) ? npqz_res: bx * BM + innerRowA + offset, param.OWOH_fastdiv) * param.stride2 - param.padding2;
const int ohow_res = fastmodulo((ksplit > 0) ? npqz_res: bx * BM + innerRowA + offset, param.OWOH_fastdiv);
const int posh_ori = fastdiv(ohow_res, param.OW_fastdiv) * param.stride1 - param.padding1;
const int posw_ori = fastmodulo(ohow_res, param.OW_fastdiv) * param.stride0 - param.padding0;
int inOffset = n * param.c * param.h * param.w * param.d;
if(vec_load){
const uint cur0 = fastdiv(start_k + innerColA * 4,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
@ -176,12 +179,13 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
const uint curC = layout == 0 ? cur2 : cur0;
const uint curR = layout == 0 ? cur0 : cur1;
const uint curS = layout == 0 ? cur1 : cur2;
const int curH = posh_ori + curR * param.d_h; // input h
const int curW = posw_ori + curS * param.d_w; // input w
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 < end_k){
const int curD = posd_ori + curT * param.dilation2; // input w
const int curH = posh_ori + curR * param.dilation1; // input h
const int curW = posw_ori + curS * param.dilation0; // input w
if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && start_k + innerColA * 4 < end_k){
int inOffsetTmp = layout == 0 ?
curH * inChannelOffset + curW * param.c + curC:
curC * inChannelOffset + curH * param.w + curW;
curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC:
curC * inDepthOffset + curD * inChannelOffset + curH * param.w + curW;
float4 tmp = reinterpret_cast<const float4 *>(&input[inOffset + inOffsetTmp])[0];
smeminput[input_sts_addr + offset + 0] = tmp.x;
smeminput[input_sts_addr + offset + BM+PAD] = tmp.y;

View File

@ -13,6 +13,7 @@
#include "ggml-cuda/concat.cuh"
#include "ggml-cuda/conv-transpose-1d.cuh"
#include "ggml-cuda/conv2d.cuh"
#include "ggml-cuda/conv3d-implicit.cuh"
#include "ggml-cuda/conv2d-dw.cuh"
#include "ggml-cuda/conv2d-transpose.cuh"
#include "ggml-cuda/convert.cuh"
@ -2629,6 +2630,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_CONV_2D:
ggml_cuda_op_conv2d(ctx, dst);
break;
case GGML_OP_CONV_3D:
ggml_cuda_op_conv3d_implicit(ctx, dst);
break;
case GGML_OP_CONV_2D_DW:
ggml_cuda_op_conv2d_dw(ctx, dst);
break;
@ -4041,6 +4045,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_IM2COL:
case GGML_OP_IM2COL_3D:
case GGML_OP_CONV_2D:
case GGML_OP_CONV_3D:
case GGML_OP_CONV_2D_DW:
case GGML_OP_CONV_TRANSPOSE_2D:
case GGML_OP_POOL_2D: