WIP: update input and kernel indexing for the depth dimension; enable GGML_OP_CONV_3D in the CUDA backend
This commit is contained in:
parent
ab15f6cd5f
commit
52455b8a6d
|
|
@ -80,7 +80,7 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
|
||||||
const uint bx = blockIdx.x;
|
const uint bx = blockIdx.x;
|
||||||
const uint by = blockIdx.y;
|
const uint by = blockIdx.y;
|
||||||
|
|
||||||
const uint PQ = param.Oh * param.Ow;
|
const uint PQZ = param.Oh * param.Ow * param.Oz;
|
||||||
|
|
||||||
// Warp tile
|
// Warp tile
|
||||||
const uint lane_id = tx % WARPSIZE;
|
const uint lane_id = tx % WARPSIZE;
|
||||||
|
|
@ -100,7 +100,8 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
|
||||||
int z = blockIdx.z;
|
int z = blockIdx.z;
|
||||||
|
|
||||||
int inChannelOffset = layout == 0 ? param.c * param.w : param.h * param.w;
|
int inChannelOffset = layout == 0 ? param.c * param.w : param.h * param.w;
|
||||||
int weightKOffset = param.c * param.r * param.s;
|
int inDepthOffset = layout == 0 ? param.h * param.c * param.w : param.d * param.h * param.w;
|
||||||
|
int weightKOffset = param.c * param.r * param.s * param.t;
|
||||||
|
|
||||||
const uint ks = (ksplit > 0) ? (weightKOffset + ksplit - 1) / ksplit : weightKOffset;
|
const uint ks = (ksplit > 0) ? (weightKOffset + ksplit - 1) / ksplit : weightKOffset;
|
||||||
const uint start_k = (ksplit > 0)? z * ks: 0;
|
const uint start_k = (ksplit > 0)? z * ks: 0;
|
||||||
|
|
@ -159,11 +160,13 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
|
||||||
const uint input_sts_addr = innerRowA + innerColA * (BM+PAD) * 4;
|
const uint input_sts_addr = innerRowA + innerColA * (BM+PAD) * 4;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) {
|
for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) {
|
||||||
int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQ : z;
|
int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQZ : z;
|
||||||
const unsigned int npq_res = (bx * BM + innerRowA + offset) % PQ;
|
const unsigned int npqz_res = (bx * BM + innerRowA + offset) % PQZ;
|
||||||
const int posh_ori = fastdiv((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.u - param.p;
|
const int posd_ori = fastdiv((ksplit > 0) ? npqz_res: bx * BM + innerRowA + offset, param.OWOH_fastdiv) * param.stride2 - param.padding2;
|
||||||
const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q;
|
const int ohow_res = fastmodulo((ksplit > 0) ? npqz_res: bx * BM + innerRowA + offset, param.OWOH_fastdiv);
|
||||||
int inOffset = n * param.c * param.h * param.w ;
|
const int posh_ori = fastdiv(ohow_res, param.OW_fastdiv) * param.stride1 - param.padding1;
|
||||||
|
const int posw_ori = fastmodulo(ohow_res, param.OW_fastdiv) * param.stride0 - param.padding0;
|
||||||
|
int inOffset = n * param.c * param.h * param.w * param.d;
|
||||||
if(vec_load){
|
if(vec_load){
|
||||||
const uint cur0 = fastdiv(start_k + innerColA * 4,
|
const uint cur0 = fastdiv(start_k + innerColA * 4,
|
||||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
|
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
|
||||||
|
|
@ -176,12 +179,13 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
|
||||||
const uint curC = layout == 0 ? cur2 : cur0;
|
const uint curC = layout == 0 ? cur2 : cur0;
|
||||||
const uint curR = layout == 0 ? cur0 : cur1;
|
const uint curR = layout == 0 ? cur0 : cur1;
|
||||||
const uint curS = layout == 0 ? cur1 : cur2;
|
const uint curS = layout == 0 ? cur1 : cur2;
|
||||||
const int curH = posh_ori + curR * param.d_h; // input h
|
const int curD = posd_ori + curT * param.dilation2; // input d
|
||||||
const int curW = posw_ori + curS * param.d_w; // input w
|
const int curH = posh_ori + curR * param.dilation1; // input h
|
||||||
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 < end_k){
|
const int curW = posw_ori + curS * param.dilation0; // input w
|
||||||
|
if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && start_k + innerColA * 4 < end_k){
|
||||||
int inOffsetTmp = layout == 0 ?
|
int inOffsetTmp = layout == 0 ?
|
||||||
curH * inChannelOffset + curW * param.c + curC:
|
curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC:
|
||||||
curC * inChannelOffset + curH * param.w + curW;
|
curC * inDepthOffset + curD * inChannelOffset + curH * param.w + curW;
|
||||||
float4 tmp = reinterpret_cast<const float4 *>(&input[inOffset + inOffsetTmp])[0];
|
float4 tmp = reinterpret_cast<const float4 *>(&input[inOffset + inOffsetTmp])[0];
|
||||||
smeminput[input_sts_addr + offset + 0] = tmp.x;
|
smeminput[input_sts_addr + offset + 0] = tmp.x;
|
||||||
smeminput[input_sts_addr + offset + BM+PAD] = tmp.y;
|
smeminput[input_sts_addr + offset + BM+PAD] = tmp.y;
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@
|
||||||
#include "ggml-cuda/concat.cuh"
|
#include "ggml-cuda/concat.cuh"
|
||||||
#include "ggml-cuda/conv-transpose-1d.cuh"
|
#include "ggml-cuda/conv-transpose-1d.cuh"
|
||||||
#include "ggml-cuda/conv2d.cuh"
|
#include "ggml-cuda/conv2d.cuh"
|
||||||
|
#include "ggml-cuda/conv3d-implicit.cuh"
|
||||||
#include "ggml-cuda/conv2d-dw.cuh"
|
#include "ggml-cuda/conv2d-dw.cuh"
|
||||||
#include "ggml-cuda/conv2d-transpose.cuh"
|
#include "ggml-cuda/conv2d-transpose.cuh"
|
||||||
#include "ggml-cuda/convert.cuh"
|
#include "ggml-cuda/convert.cuh"
|
||||||
|
|
@ -2629,6 +2630,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||||
case GGML_OP_CONV_2D:
|
case GGML_OP_CONV_2D:
|
||||||
ggml_cuda_op_conv2d(ctx, dst);
|
ggml_cuda_op_conv2d(ctx, dst);
|
||||||
break;
|
break;
|
||||||
|
case GGML_OP_CONV_3D:
|
||||||
|
ggml_cuda_op_conv3d_implicit(ctx, dst);
|
||||||
|
break;
|
||||||
case GGML_OP_CONV_2D_DW:
|
case GGML_OP_CONV_2D_DW:
|
||||||
ggml_cuda_op_conv2d_dw(ctx, dst);
|
ggml_cuda_op_conv2d_dw(ctx, dst);
|
||||||
break;
|
break;
|
||||||
|
|
@ -4041,6 +4045,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||||
case GGML_OP_IM2COL:
|
case GGML_OP_IM2COL:
|
||||||
case GGML_OP_IM2COL_3D:
|
case GGML_OP_IM2COL_3D:
|
||||||
case GGML_OP_CONV_2D:
|
case GGML_OP_CONV_2D:
|
||||||
|
case GGML_OP_CONV_3D:
|
||||||
case GGML_OP_CONV_2D_DW:
|
case GGML_OP_CONV_2D_DW:
|
||||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||||
case GGML_OP_POOL_2D:
|
case GGML_OP_POOL_2D:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue