WIP: updating indices for input and kernel; enable OP_CONV_3D for cuda backend
This commit is contained in:
parent
ab15f6cd5f
commit
52455b8a6d
|
|
@ -80,7 +80,7 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
|
|||
const uint bx = blockIdx.x;
|
||||
const uint by = blockIdx.y;
|
||||
|
||||
const uint PQ = param.Oh * param.Ow;
|
||||
const uint PQZ = param.Oh * param.Ow * param.Oz;
|
||||
|
||||
// Warp tile
|
||||
const uint lane_id = tx % WARPSIZE;
|
||||
|
|
@ -100,7 +100,8 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
|
|||
int z = blockIdx.z;
|
||||
|
||||
int inChannelOffset = layout == 0 ? param.c * param.w : param.h * param.w;
|
||||
int weightKOffset = param.c * param.r * param.s;
|
||||
int inDepthOffset = layout == 0 ? param.h * param.c * param.w : param.d * param.h * param.w;
|
||||
int weightKOffset = param.c * param.r * param.s * param.t;
|
||||
|
||||
const uint ks = (ksplit > 0) ? (weightKOffset + ksplit - 1) / ksplit : weightKOffset;
|
||||
const uint start_k = (ksplit > 0)? z * ks: 0;
|
||||
|
|
@ -159,11 +160,13 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
|
|||
const uint input_sts_addr = innerRowA + innerColA * (BM+PAD) * 4;
|
||||
#pragma unroll
|
||||
for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) {
|
||||
int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQ : z;
|
||||
const unsigned int npq_res = (bx * BM + innerRowA + offset) % PQ;
|
||||
const int posh_ori = fastdiv((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.u - param.p;
|
||||
const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q;
|
||||
int inOffset = n * param.c * param.h * param.w ;
|
||||
int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQZ : z;
|
||||
const unsigned int npqz_res = (bx * BM + innerRowA + offset) % PQZ;
|
||||
const int posd_ori = fastdiv((ksplit > 0) ? npqz_res: bx * BM + innerRowA + offset, param.OWOH_fastdiv) * param.stride2 - param.padding2;
|
||||
const int ohow_res = fastmodulo((ksplit > 0) ? npqz_res: bx * BM + innerRowA + offset, param.OWOH_fastdiv);
|
||||
const int posh_ori = fastdiv(ohow_res, param.OW_fastdiv) * param.stride1 - param.padding1;
|
||||
const int posw_ori = fastmodulo(ohow_res, param.OW_fastdiv) * param.stride0 - param.padding0;
|
||||
int inOffset = n * param.c * param.h * param.w * param.d;
|
||||
if(vec_load){
|
||||
const uint cur0 = fastdiv(start_k + innerColA * 4,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
|
||||
|
|
@ -176,12 +179,13 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
|
|||
const uint curC = layout == 0 ? cur2 : cur0;
|
||||
const uint curR = layout == 0 ? cur0 : cur1;
|
||||
const uint curS = layout == 0 ? cur1 : cur2;
|
||||
const int curH = posh_ori + curR * param.d_h; // input h
|
||||
const int curW = posw_ori + curS * param.d_w; // input w
|
||||
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 < end_k){
|
||||
const int curD = posd_ori + curT * param.dilation2; // input w
|
||||
const int curH = posh_ori + curR * param.dilation1; // input h
|
||||
const int curW = posw_ori + curS * param.dilation0; // input w
|
||||
if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && start_k + innerColA * 4 < end_k){
|
||||
int inOffsetTmp = layout == 0 ?
|
||||
curH * inChannelOffset + curW * param.c + curC:
|
||||
curC * inChannelOffset + curH * param.w + curW;
|
||||
curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC:
|
||||
curC * inDepthOffset + curD * inChannelOffset + curH * param.w + curW;
|
||||
float4 tmp = reinterpret_cast<const float4 *>(&input[inOffset + inOffsetTmp])[0];
|
||||
smeminput[input_sts_addr + offset + 0] = tmp.x;
|
||||
smeminput[input_sts_addr + offset + BM+PAD] = tmp.y;
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@
|
|||
#include "ggml-cuda/concat.cuh"
|
||||
#include "ggml-cuda/conv-transpose-1d.cuh"
|
||||
#include "ggml-cuda/conv2d.cuh"
|
||||
#include "ggml-cuda/conv3d-implicit.cuh"
|
||||
#include "ggml-cuda/conv2d-dw.cuh"
|
||||
#include "ggml-cuda/conv2d-transpose.cuh"
|
||||
#include "ggml-cuda/convert.cuh"
|
||||
|
|
@ -2629,6 +2630,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|||
case GGML_OP_CONV_2D:
|
||||
ggml_cuda_op_conv2d(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CONV_3D:
|
||||
ggml_cuda_op_conv3d_implicit(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CONV_2D_DW:
|
||||
ggml_cuda_op_conv2d_dw(ctx, dst);
|
||||
break;
|
||||
|
|
@ -4041,6 +4045,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
case GGML_OP_IM2COL:
|
||||
case GGML_OP_IM2COL_3D:
|
||||
case GGML_OP_CONV_2D:
|
||||
case GGML_OP_CONV_3D:
|
||||
case GGML_OP_CONV_2D_DW:
|
||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||
case GGML_OP_POOL_2D:
|
||||
|
|
|
|||
Loading…
Reference in New Issue