WIP: build ok

This commit is contained in:
bssrdf 2025-11-02 10:34:03 -05:00
parent 52455b8a6d
commit 0a64ea8ff8
2 changed files with 236 additions and 140 deletions

View File

@ -62,6 +62,29 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co
}
}
template<const int layout>
__device__ int4 inputIndices(const uint kidx, param_t param) {
const uint cur0 = fastdiv(kidx,
layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset
const uint cur0_res = fastmodulo(kidx,
layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset
const uint cur1 = fastdiv(cur0_res,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset
const uint cur1_res = fastmodulo(cur0_res,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset
const uint cur2 = fastdiv(cur1_res,
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
const uint cur3 = fastmodulo(cur1_res,
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
const uint curC = layout == 0 ? cur3 : cur0;
const uint curT = layout == 0 ? cur0 : cur1;
const uint curR = layout == 0 ? cur1 : cur2;
const uint curS = layout == 0 ? cur2 : cur3;
return make_int4(curC, curT, curR, curS);
}
template<typename T, const int BM, const int BN, const int BK, const int WM, const int WN,
const int WNITER, const int TM, const int TN, const int NUM_THREADS,
// layout: 0, NHWC; 1, NCHW
@ -80,7 +103,7 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
const uint bx = blockIdx.x;
const uint by = blockIdx.y;
const uint PQZ = param.Oh * param.Ow * param.Oz;
const uint PQZ = param.Oh * param.Ow * param.Od;
// Warp tile
const uint lane_id = tx % WARPSIZE;
@ -102,6 +125,7 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
int inChannelOffset = layout == 0 ? param.c * param.w : param.h * param.w;
int inDepthOffset = layout == 0 ? param.h * param.c * param.w : param.d * param.h * param.w;
int weightKOffset = param.c * param.r * param.s * param.t;
int inNOffset = param.c * param.w * param.h * param.d;
const uint ks = (ksplit > 0) ? (weightKOffset + ksplit - 1) / ksplit : weightKOffset;
const uint start_k = (ksplit > 0)? z * ks: 0;
@ -158,31 +182,44 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
const uint input_sts_addr = innerRowA + innerColA * (BM+PAD) * 4;
const uint inKOffset = start_k + innerColA * 4;
#pragma unroll
for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) {
int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQZ : z;
const unsigned int npqz_res = (bx * BM + innerRowA + offset) % PQZ;
const int posd_ori = fastdiv((ksplit > 0) ? npqz_res: bx * BM + innerRowA + offset, param.OWOH_fastdiv) * param.stride2 - param.padding2;
const int ohow_res = fastmodulo((ksplit > 0) ? npqz_res: bx * BM + innerRowA + offset, param.OWOH_fastdiv);
const unsigned int gemm_i = bx * BM + innerRowA + offset;
// int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQZ : z;
int n = (ksplit > 0) ? fastdiv(gemm_i, param.PQZ_fastdiv) : z;
const unsigned int npqz_res = fastmodulo(gemm_i, param.PQZ_fastdiv);
const int posd_ori = fastdiv((ksplit > 0) ? npqz_res: gemm_i, param.OHOW_fastdiv) * param.stride2 - param.padding2;
const int ohow_res = fastmodulo((ksplit > 0) ? npqz_res: gemm_i, param.OHOW_fastdiv);
const int posh_ori = fastdiv(ohow_res, param.OW_fastdiv) * param.stride1 - param.padding1;
const int posw_ori = fastmodulo(ohow_res, param.OW_fastdiv) * param.stride0 - param.padding0;
int inOffset = n * param.c * param.h * param.w * param.d;
int inOffset = n * inNOffset;
if(vec_load){
const uint cur0 = fastdiv(start_k + innerColA * 4,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
const uint cur1 = fastdiv(fastmodulo(start_k + innerColA * 4,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
const uint cur2 = fastmodulo(fastmodulo(start_k + innerColA * 4,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
const uint curC = layout == 0 ? cur2 : cur0;
const uint curR = layout == 0 ? cur0 : cur1;
const uint curS = layout == 0 ? cur1 : cur2;
const int curD = posd_ori + curT * param.dilation2; // input w
const int curH = posh_ori + curR * param.dilation1; // input h
const int curW = posw_ori + curS * param.dilation0; // input w
if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && start_k + innerColA * 4 < end_k){
// const uint cur0 = fastdiv(inKOffset,
// layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset
// const uint cur0_res = fastmodulo(inKOffset,
// layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset
// const uint cur1 = fastdiv(cur0_res,
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset
// const uint cur1_res = fastmodulo(cur0_res,
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset
// const uint cur2 = fastdiv(cur1_res,
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
// const uint cur3 = fastmodulo(cur1_res,
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
// const uint curC = layout == 0 ? cur3 : cur0;
// const uint curT = layout == 0 ? cur0 : cur1;
// const uint curR = layout == 0 ? cur1 : cur2;
// const uint curS = layout == 0 ? cur2 : cur3;
const int4 curIdx = inputIndices<layout>(inKOffset, param);
// const int curD = posd_ori + curT * param.dilation2; // input w
// const int curH = posh_ori + curR * param.dilation1; // input h
// const int curW = posw_ori + curS * param.dilation0; // input w
const int curD = posd_ori + curIdx.y * param.dilation2; // input w
const int curH = posh_ori + curIdx.z * param.dilation1; // input h
const int curW = posw_ori + curIdx.w * param.dilation0; // input w
const int curC = curIdx.x;
if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && inKOffset < end_k){
int inOffsetTmp = layout == 0 ?
curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC:
curC * inDepthOffset + curD * inChannelOffset + curH * param.w + curW;
@ -199,23 +236,47 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
} else {
#pragma unroll
for (int i = 0; i < 4; ++i){
const uint cur0 = fastdiv(start_k + innerColA * 4 + i,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
const uint cur1 = fastdiv(fastmodulo(start_k + innerColA * 4 + i,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
const uint cur2 = fastmodulo(fastmodulo(start_k + innerColA * 4 + i,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
const uint curC = layout == 0 ? cur2 : cur0;
const uint curR = layout == 0 ? cur0 : cur1;
const uint curS = layout == 0 ? cur1 : cur2;
const int curH = posh_ori + curR * param.d_h; // input h
const int curW = posw_ori + curS * param.d_w; // input w
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 + i < end_k){
// const uint cur0 = fastdiv(inKOffset + i,
// layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset
// const uint cur0_res = fastmodulo(inKOffset + i,
// layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset
// const uint cur1 = fastdiv(cur0_res,
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset
// const uint cur1_res = fastmodulo(cur0_res,
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset
// const uint cur2 = fastdiv(cur1_res,
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
// const uint cur3 = fastmodulo(cur1_res,
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
// const uint curC = layout == 0 ? cur3 : cur0;
// const uint curT = layout == 0 ? cur0 : cur1;
// const uint curR = layout == 0 ? cur1 : cur2;
// const uint curS = layout == 0 ? cur2 : cur3;
const int4 curIdx = inputIndices<layout>(inKOffset + i, param);
// const int curD = posd_ori + curT * param.dilation2; // input w
// const int curH = posh_ori + curR * param.dilation1; // input h
// const int curW = posw_ori + curS * param.dilation0; // input w
const int curD = posd_ori + curIdx.y * param.dilation2; // input w
const int curH = posh_ori + curIdx.z * param.dilation1; // input h
const int curW = posw_ori + curIdx.w * param.dilation0; // input w
const int curC = curIdx.x;
// const uint cur0 = fastdiv(start_k + innerColA * 4 + i,
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
// const uint cur1 = fastdiv(fastmodulo(start_k + innerColA * 4 + i,
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
// const uint cur2 = fastmodulo(fastmodulo(start_k + innerColA * 4 + i,
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
// const uint curC = layout == 0 ? cur2 : cur0;
// const uint curR = layout == 0 ? cur0 : cur1;
// const uint curS = layout == 0 ? cur1 : cur2;
// const int curH = posh_ori + curR * param.d_h; // input h
// const int curW = posw_ori + curS * param.d_w; // input w
if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && inKOffset + i < end_k){
int inOffsetTmp = layout == 0 ?
curH * inChannelOffset + curW * param.c + curC:
curC * inChannelOffset + curH * param.w + curW;
curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC:
curC * inDepthOffset + curD * inChannelOffset + curH * param.w + curW;
smeminput[input_sts_addr + offset + i*(BM+PAD)] = input[inOffset + inOffsetTmp];
} else {
smeminput[input_sts_addr + offset + i*(BM+PAD)] = 0.f;
@ -242,6 +303,7 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
weight_frag[0][wSubColIdx * TN + i] = smemweight[weight_lds_addr + wSubColIdx * WSUBN +
threadColInWarp * TN + i];
// main block k loop
for (int crs = start_k; crs < end_k; crs += BK) {
int load_flag = write_flag ^ 1;
@ -317,33 +379,50 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
}
}
}
const uint inKkOffset = innerColA * 4 + crs + BK;
#pragma unroll
for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) {
int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQ : z;
const unsigned int npq_res = (bx * BM + innerRowA + offset) % PQ;
const int posh_ori = fastdiv((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.u - param.p;
const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q;
int inOffset = n * param.c * param.h * param.w ;
// int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQ : z;
// const unsigned int npq_res = (bx * BM + innerRowA + offset) % PQ;
// const int posh_ori = fastdiv((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.u - param.p;
// const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q;
// int inOffset = n * param.c * param.h * param.w ;
const unsigned int gemm_i = bx * BM + innerRowA + offset;
int n = (ksplit > 0) ? fastdiv(gemm_i, param.PQZ_fastdiv) : z;
const unsigned int npqz_res = fastmodulo(gemm_i, param.PQZ_fastdiv);
const int posd_ori = fastdiv((ksplit > 0) ? npqz_res: gemm_i, param.OHOW_fastdiv) * param.stride2 - param.padding2;
const int ohow_res = fastmodulo((ksplit > 0) ? npqz_res: gemm_i, param.OHOW_fastdiv);
const int posh_ori = fastdiv(ohow_res, param.OW_fastdiv) * param.stride1 - param.padding1;
const int posw_ori = fastmodulo(ohow_res, param.OW_fastdiv) * param.stride0 - param.padding0;
int inOffset = n * inNOffset;
if(vec_load){
const uint cur0 = fastdiv(innerColA * 4 + crs + BK,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
const uint curC = layout == 0 ? cur2 : cur0;
const uint curR = layout == 0 ? cur0 : cur1;
const uint curS = layout == 0 ? cur1 : cur2;
const int4 curIdx = inputIndices<layout>(inKkOffset, param);
const int curD = posd_ori + curIdx.y * param.dilation2; // input w
const int curH = posh_ori + curIdx.z * param.dilation1; // input h
const int curW = posw_ori + curIdx.w * param.dilation0; // input w
const int curC = curIdx.x;
// const uint cur0 = fastdiv(innerColA * 4 + crs + BK,
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
// const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK,
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
// const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK,
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
// const uint curC = layout == 0 ? cur2 : cur0;
// const uint curR = layout == 0 ? cur0 : cur1;
// const uint curS = layout == 0 ? cur1 : cur2;
const int curH = posh_ori + curR * param.d_h; // input h
const int curW = posw_ori + curS * param.d_w; // input w
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK < end_k){
// int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
// const int curH = posh_ori + curR * param.d_h; // input h
// const int curW = posw_ori + curS * param.d_w; // input w
if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && inKkOffset < end_k){
int inOffsetTmp = layout == 0 ?
curH * inChannelOffset + curW * param.c + curC:
curC * inChannelOffset + curH * param.w + curW;
curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC:
curC * inDepthOffset + curD * inChannelOffset + curH * param.w + curW;
// if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && inKkOffset < end_k){
// int inOffsetTmp = layout == 0 ?
// curH * inChannelOffset + curW * param.c + curC:
// curC * inChannelOffset + curH * param.w + curW;
float4 tmp = reinterpret_cast<const float4 *>(&input[inOffset + inOffsetTmp])[0];
smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + 0] = tmp.x;
smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + BM+PAD] = tmp.y;
@ -357,24 +436,33 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
} else {
#pragma unroll
for (int i = 0; i < 4; ++i){
const uint cur0 = fastdiv(innerColA * 4 + crs + BK + i,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK + i,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK + i,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
const uint curC = layout == 0 ? cur2 : cur0;
const uint curR = layout == 0 ? cur0 : cur1;
const uint curS = layout == 0 ? cur1 : cur2;
// const uint cur0 = fastdiv(innerColA * 4 + crs + BK + i,
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
// const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK + i,
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
// const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK + i,
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
// const uint curC = layout == 0 ? cur2 : cur0;
// const uint curR = layout == 0 ? cur0 : cur1;
// const uint curS = layout == 0 ? cur1 : cur2;
const int curH = posh_ori + curR * param.d_h; // input h
const int curW = posw_ori + curS * param.d_w; // input w
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK + i < end_k){
// const int curH = posh_ori + curR * param.d_h; // input h
// const int curW = posw_ori + curS * param.d_w; // input w
const int4 curIdx = inputIndices<layout>(inKkOffset + i, param);
const int curD = posd_ori + curIdx.y * param.dilation2; // input w
const int curH = posh_ori + curIdx.z * param.dilation1; // input h
const int curW = posw_ori + curIdx.w * param.dilation0; // input w
const int curC = curIdx.x;
// if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK + i < end_k){
// int inOffsetTmp = layout == 0 ?
// curH * inChannelOffset + curW * param.c + curC:
// curC * inChannelOffset + curH * param.w + curW;
if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && inKkOffset + i < end_k){
int inOffsetTmp = layout == 0 ?
curH * inChannelOffset + curW * param.c + curC:
curC * inChannelOffset + curH * param.w + curW;
curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC:
curC * inDepthOffset + curD * inChannelOffset + curH * param.w + curW;
smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + i*(BM+PAD)] = input[inOffset + inOffsetTmp];
} else {
smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + i*(BM+PAD)] = 0.f;
@ -448,15 +536,16 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
__syncthreads();
#pragma unroll
for (int subk = 0; subk < TM * TN; ++subk){
// output: [N*OC, OD, OH, OW]
const uint row = m_idx + j * WSUBN + (lane_id + subk * WARPSIZE) / WSUBM;
const uint gemm_i = n_idx + i * WSUBM + (lane_id + subk * WARPSIZE) % WSUBM;
const int n = (ksplit > 0) ? gemm_i / PQ : z;
const int col = (ksplit > 0) ? gemm_i % PQ : gemm_i;
if (n < param.n && row < param.k && col < param.Oh * param.Ow){
const int n = (ksplit > 0) ? fastdiv(gemm_i, param.PQZ_fastdiv) : z;
const int col = (ksplit > 0) ? fastmodulo(gemm_i, param.PQZ_fastdiv) : gemm_i;
if (n < param.n && row < param.k && col < PQZ){
const uint outOffset = ksplit > 0 ?
z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow +
row * param.Oh * param.Ow + col :
z * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col;
// z * param.n * param.k * PQZ + n * param.k * PQZ + row * PQZ + col :
((z * param.n + n) * param.k + row) * PQZ + col :
(z * param.k + row) * PQZ + col;
output[outOffset] = smemoutput[output_lds_addr + subk * WARPSIZE];
}
}
@ -464,7 +553,7 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
}
}
#if 0
template <unsigned int mma_tiles_per_warp_m, unsigned int mma_tiles_per_warp_k, unsigned int smem_stride>
__device__ __forceinline__ void ldmatrix_a(
@ -885,6 +974,7 @@ constexpr unsigned int MMA_N = 8;
#endif
}
#endif
#define NUM_VARIANTS 4
@ -925,70 +1015,70 @@ static void conv3d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D,
dim3 thblock(NUM_THREADS, thready, threadz);
dim3 grid(blockx, blocky, blockz);
conv2d_implicit_kernel<T, BM, BN, BK, WM, WN,
conv3d_implicit_kernel<T, BM, BN, BK, WM, WN,
WNITER, TM, TN, NUM_THREADS, 1, false, 0><<<grid, thblock, 0, st>>>(X_D, K_D, Y_D, P);
}
static void conv3d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const float * X_D, const half * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) {
if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0 && (P.r > 1 || P.s > 1 || P.t > 1)) {
// if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0 && (P.r > 1 || P.s > 1 || P.t > 1)) {
int id = ggml_cuda_get_device();
// int id = ggml_cuda_get_device();
int64_t ne = P.c * P.h * P.w * P.n;
int64_t ne00 = P.c;
int64_t ne01 = P.h * P.w;
ggml_cuda_pool_alloc<half> input_f16(ctx.pool(id), ne);
// int64_t ne = P.c * P.h * P.w * P.n;
// int64_t ne00 = P.c;
// int64_t ne01 = P.h * P.w;
// ggml_cuda_pool_alloc<half> input_f16(ctx.pool(id), ne);
dim3 dimGrid( (ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
(ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
(ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ;
dim3 dimBlock(CUDA_NCHW_2_NHWC_TILE_DIM,CUDA_NCHW_2_NHWC_BLOCK_ROWS, 1);
NCHW2NHWC<float, half><<<dimGrid, dimBlock, 0, st>>>(X_D, input_f16.get(), ne, ne00, ne01);
// dim3 dimGrid( (ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
// (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
// (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ;
// dim3 dimBlock(CUDA_NCHW_2_NHWC_TILE_DIM,CUDA_NCHW_2_NHWC_BLOCK_ROWS, 1);
// NCHW2NHWC<float, half><<<dimGrid, dimBlock, 0, st>>>(X_D, input_f16.get(), ne, ne00, ne01);
ne = P.c * P.r * P.s * P.k;
ne01 = P.r * P.s;
ggml_cuda_pool_alloc<half> kernel_f16(ctx.pool(id), ne);
dim3 dimGrid1((ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
(ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
(ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ;
NCHW2NHWC<half, half><<<dimGrid1, dimBlock, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
// ne = P.c * P.r * P.s * P.k;
// ne01 = P.r * P.s;
// ggml_cuda_pool_alloc<half> kernel_f16(ctx.pool(id), ne);
// dim3 dimGrid1((ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
// (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
// (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ;
// NCHW2NHWC<half, half><<<dimGrid1, dimBlock, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
const half *X_H = input_f16.get();
const half *K_H = kernel_f16.get();
ggml_cuda_pool_alloc<half> Y_H(ctx.pool(id), P.k * P.Oh * P.Ow * P.n);
// const half *X_H = input_f16.get();
// const half *K_H = kernel_f16.get();
// ggml_cuda_pool_alloc<half> Y_H(ctx.pool(id), P.k * P.Oh * P.Ow * P.n);
constexpr unsigned int BM_dim = 256;
constexpr unsigned int BN_dim = 256;
constexpr unsigned int BK_dim = 32;
// constexpr unsigned int BM_dim = 256;
// constexpr unsigned int BN_dim = 256;
// constexpr unsigned int BK_dim = 32;
constexpr unsigned int WARPS_PER_BLOCK_M = 2;
constexpr unsigned int WARPS_PER_BLOCK_N = 4;
constexpr unsigned int WARPS_PER_BLOCK_K = 4;
// constexpr unsigned int WARPS_PER_BLOCK_M = 2;
// constexpr unsigned int WARPS_PER_BLOCK_N = 4;
// constexpr unsigned int WARPS_PER_BLOCK_K = 4;
constexpr unsigned int WM_dim = BM_dim / WARPS_PER_BLOCK_M;
constexpr unsigned int WN_dim = BN_dim / WARPS_PER_BLOCK_N;
constexpr unsigned int WK_dim = BK_dim / WARPS_PER_BLOCK_K;
const unsigned int BlocksM = (P.n * P.Oh * P.Ow + BM_dim - 1) / BM_dim;
const unsigned int BlocksN = (P.k + BN_dim - 1) / BN_dim;
constexpr unsigned int ThreadsM = WARPS_PER_BLOCK_M;
constexpr unsigned int ThreadsN = WARPSIZE * WARPS_PER_BLOCK_N;
constexpr unsigned int NumThreads = ThreadsM * ThreadsN;
const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half);
// constexpr unsigned int WM_dim = BM_dim / WARPS_PER_BLOCK_M;
// constexpr unsigned int WN_dim = BN_dim / WARPS_PER_BLOCK_N;
// constexpr unsigned int WK_dim = BK_dim / WARPS_PER_BLOCK_K;
// const unsigned int BlocksM = (P.n * P.Oh * P.Ow + BM_dim - 1) / BM_dim;
// const unsigned int BlocksN = (P.k + BN_dim - 1) / BN_dim;
// constexpr unsigned int ThreadsM = WARPS_PER_BLOCK_M;
// constexpr unsigned int ThreadsN = WARPSIZE * WARPS_PER_BLOCK_N;
// constexpr unsigned int NumThreads = ThreadsM * ThreadsN;
// const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half);
cudaFuncSetAttribute(conv3d_implicit_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, NumThreads>,
cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75
dim3 gridDim(BlocksN, BlocksM);
dim3 blockDim(ThreadsN, ThreadsM);
// cudaFuncSetAttribute(conv3d_implicit_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, NumThreads>,
// cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75
// dim3 gridDim(BlocksN, BlocksM);
// dim3 blockDim(ThreadsN, ThreadsM);
conv3d_implicit_kernel<BM_dim, BN_dim, BK_dim,
WM_dim, WN_dim, WK_dim, NumThreads>
<<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_H, Y_H.get(), P);
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow * P.n, st);
} else{
// conv3d_implicit_kernel<BM_dim, BN_dim, BK_dim,
// WM_dim, WN_dim, WK_dim, NumThreads>
// <<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_H, Y_H.get(), P);
// const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
// to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow * P.n, st);
// } else{
conv3d_implicit_cuda<half, 1>(X_D, K_D, Y_D, P, st);
}
// }
}
@ -1056,7 +1146,10 @@ void ggml_cuda_op_conv3d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor *
init_fastdiv_values(IC),
init_fastdiv_values(KW*KH),
init_fastdiv_values(KW),
init_fastdiv_values(OW*OH)};
init_fastdiv_values(OW*OH),
init_fastdiv_values(OW*OH*OD),
init_fastdiv_values(KW*KH*IC),
init_fastdiv_values(KW*KH*KD)};
if (kernel->type == GGML_TYPE_F16) {
conv3d_implicit_cuda_f16(ctx, X_D, (half *) K_D, Y_D, cc, params, st);

View File

@ -29,9 +29,13 @@ typedef struct{
uint3 RS_fastdiv;
uint3 S_fastdiv;
uint3 OHOW_fastdiv;
uint3 PQZ_fastdiv;
uint3 RSC_fastdiv;
uint3 TRS_fastdiv;
} param_t;
// same as above, but writes are swizzled to avoid bank conflicts when shared memory is read later in the kernel
template<unsigned int TILE_ROWS,
unsigned int NUM_THREADS>
@ -131,14 +135,14 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row;
unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv);
unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p;
int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q;
int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.stride1 - param.padding1;
int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.stride0 - param.padding0;
unsigned int inOffset = n * param.c * param.h * param.w;
const unsigned int curR = fastdiv(thread_col*8, param.SC_fastdiv); // channel offset
const unsigned int curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
const unsigned int curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
int curH = posh_ori + curR * param.d_h; // input h
int curW = posw_ori + curS * param.d_w; // input w
int curH = posh_ori + curR * param.dilation1; // input h
int curW = posw_ori + curS * param.dilation0; // input w
// apply swizzle to the dst index
unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
@ -197,14 +201,14 @@ __device__ __forceinline__ void tileMemcpyLoadA(
unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row;
unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv);
unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p;
int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q;
int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.stride1 - param.padding1;
int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.stride0 - param.padding0;
unsigned int inOffset = n * param.c * param.h * param.w;
const unsigned int curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // channel offset
const unsigned int curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
int curH = posh_ori + curR * param.d_h; // input h
int curW = posw_ori + curS * param.d_w; // input w
int curH = posh_ori + curR * param.dilation1; // input h
int curW = posw_ori + curS * param.dilation0; // input w
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
curR < param.r && curS < param.s && curC < param.c){
const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
@ -348,6 +352,5 @@ __device__ __forceinline__ uint32_t cvta_to_shared_u32(const void *pointer) {
return address;
}
#define CUDA_CONV3D_IMPLICT_BLOCK_SIZE 256
void ggml_cuda_op_conv3d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst);