Merge branch 'refactor-cuda-core-path' into conv2d-implicit

This commit is contained in:
bssrdf 2025-11-06 11:05:06 -05:00
commit 28b7094750
2 changed files with 150 additions and 197 deletions

View File

@ -27,6 +27,7 @@ static __global__ void reduce_f32(const src_T * __restrict__ x, dst_T * __restri
}
}
template <typename src_T, typename dst_T>
static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, const int ne00, const int ne01){
@ -63,6 +64,8 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co
}
}
template<typename T, const int BM, const int BN, const int BK, const int WM, const int WN,
const int WNITER, const int TM, const int TN, const int NUM_THREADS,
// layout: 0, NHWC; 1, NCHW
@ -82,6 +85,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
const uint by = blockIdx.y;
const uint PQ = param.Oh * param.Ow;
const uint CHW = param.c * param.h * param.w;
// Warp tile
const uint lane_id = tx % WARPSIZE;
@ -119,107 +123,14 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
constexpr uint rowStrideA = (NUM_THREADS * 4) / BK;
// ldg
const uint weight_sts_addr = innerRowA + innerColA * (BN+PAD) * 4;
#pragma unroll
for (uint offset = 0; offset + rowStrideA <= BN; offset += rowStrideA) {
if(vec_load){
if (by * BN + innerRowA + offset < param.k && start_k + innerColA * 4 < end_k){
if constexpr (std::is_same_v<T, float>){
float4 tmp = reinterpret_cast<const float4 *>(&kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4])[0];
smemweight[weight_sts_addr + offset + 0] = tmp.x;
smemweight[weight_sts_addr + offset + (BN+PAD)] = tmp.y;
smemweight[weight_sts_addr + offset + 2*(BN+PAD)] = tmp.z;
smemweight[weight_sts_addr + offset + 3*(BN+PAD)] = tmp.w;
}else{ // read 4 halves
float2 tmp = reinterpret_cast<const float2 *>(&kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4])[0];
const half *val = reinterpret_cast<const half *>(&tmp);
smemweight[weight_sts_addr + offset + 0] = val[0];
smemweight[weight_sts_addr + offset + (BN+PAD)] = val[1];
smemweight[weight_sts_addr + offset + 2*(BN+PAD)] = val[2];
smemweight[weight_sts_addr + offset + 3*(BN+PAD)] = val[3];
}
} else {
#pragma unroll
for (int i = 0; i < 4; ++i){
smemweight[weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f;
}
}
}else{
#pragma unroll
for (int i = 0; i < 4; ++i){
if (by * BN + innerRowA + offset < param.k && start_k + innerColA * 4 + i < end_k){
smemweight[weight_sts_addr + offset + i*(BN+PAD)] = kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4 + i];
} else {
smemweight[weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f;
}
}
}
}
loadFilter<T, BN, rowStrideA, layout, vec_load, ksplit, PAD>
(kernel, smemweight, by, innerRowA, innerColA, weightKOffset,
start_k, end_k, param);
loadInput<BM, rowStrideA, layout, vec_load, ksplit, PAD>
(input, smeminput, bx, innerRowA, innerColA,
start_k, end_k, PQ, CHW, inChannelOffset, param);
const uint input_sts_addr = innerRowA + innerColA * (BM+PAD) * 4;
#pragma unroll
for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) {
int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQ : z;
const unsigned int npq_res = (bx * BM + innerRowA + offset) % PQ;
const int posh_ori = fastdiv((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.u - param.p;
const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q;
int inOffset = n * param.c * param.h * param.w ;
if(vec_load){
const uint cur0 = fastdiv(start_k + innerColA * 4,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
const uint cur1 = fastdiv(fastmodulo(start_k + innerColA * 4,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
const uint cur2 = fastmodulo(fastmodulo(start_k + innerColA * 4,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
const uint curC = layout == 0 ? cur2 : cur0;
const uint curR = layout == 0 ? cur0 : cur1;
const uint curS = layout == 0 ? cur1 : cur2;
const int curH = posh_ori + curR * param.d_h; // input h
const int curW = posw_ori + curS * param.d_w; // input w
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 < end_k){
int inOffsetTmp = layout == 0 ?
curH * inChannelOffset + curW * param.c + curC:
curC * inChannelOffset + curH * param.w + curW;
float4 tmp = reinterpret_cast<const float4 *>(&input[inOffset + inOffsetTmp])[0];
smeminput[input_sts_addr + offset + 0] = tmp.x;
smeminput[input_sts_addr + offset + BM+PAD] = tmp.y;
smeminput[input_sts_addr + offset + 2*(BM+PAD)] = tmp.z;
smeminput[input_sts_addr + offset + 3*(BM+PAD)] = tmp.w;
} else {
#pragma unroll
for (int i = 0; i < 4; ++i)
smeminput[input_sts_addr + offset + i*(BM+PAD)] = 0.f;
}
} else {
#pragma unroll
for (int i = 0; i < 4; ++i){
const uint cur0 = fastdiv(start_k + innerColA * 4 + i,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
const uint cur1 = fastdiv(fastmodulo(start_k + innerColA * 4 + i,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
const uint cur2 = fastmodulo(fastmodulo(start_k + innerColA * 4 + i,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
const uint curC = layout == 0 ? cur2 : cur0;
const uint curR = layout == 0 ? cur0 : cur1;
const uint curS = layout == 0 ? cur1 : cur2;
const int curH = posh_ori + curR * param.d_h; // input h
const int curW = posw_ori + curS * param.d_w; // input w
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 + i < end_k){
int inOffsetTmp = layout == 0 ?
curH * inChannelOffset + curW * param.c + curC:
curC * inChannelOffset + curH * param.w + curW;
smeminput[input_sts_addr + offset + i*(BM+PAD)] = input[inOffset + inOffsetTmp];
} else {
smeminput[input_sts_addr + offset + i*(BM+PAD)] = 0.f;
}
}
}
}
__syncthreads();
// lds
@ -279,106 +190,15 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
}
}
// ldg
#pragma unroll
for (uint offset = 0; offset + rowStrideA <= BN; offset += rowStrideA) {
if(vec_load){
if (by * BN + innerRowA + offset < param.k && innerColA * 4 + crs + BK < end_k){
if constexpr (std::is_same_v<T, float>){
float4 tmp = reinterpret_cast<const float4 *>(&kernel[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK])[0];
smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 0] = tmp.x;
smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + (BN+PAD)] = tmp.y;
smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 2*(BN+PAD)] = tmp.z;
smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 3*(BN+PAD)] = tmp.w;
} else {
float2 tmp = reinterpret_cast<const float2 *>(&kernel[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK])[0];
const half *val = reinterpret_cast<const half *>(&tmp);
smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 0] = val[0];
smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + (BN+PAD)] = val[1];
smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 2*(BN+PAD)] = val[2];
smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 3*(BN+PAD)] = val[3];
}
} else {
#pragma unroll
for (int i = 0; i < 4; ++i)
smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f;
}
}else{
#pragma unroll
for (int i = 0; i < 4; ++i){
if (by * BN + innerRowA + offset < param.k && innerColA * 4 + crs + BK + i < end_k){
// float4 tmp = reinterpret_cast<float4 *>(&param.weight[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK + i])[0];
smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + i*(BN+PAD)] = kernel[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK + i];
} else {
smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f;
}
}
}
}
#pragma unroll
for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) {
int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQ : z;
const unsigned int npq_res = (bx * BM + innerRowA + offset) % PQ;
const int posh_ori = fastdiv((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.u - param.p;
const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q;
int inOffset = n * param.c * param.h * param.w ;
if(vec_load){
const uint cur0 = fastdiv(innerColA * 4 + crs + BK,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
const uint curC = layout == 0 ? cur2 : cur0;
const uint curR = layout == 0 ? cur0 : cur1;
const uint curS = layout == 0 ? cur1 : cur2;
const int curH = posh_ori + curR * param.d_h; // input h
const int curW = posw_ori + curS * param.d_w; // input w
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK < end_k){
// int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
int inOffsetTmp = layout == 0 ?
curH * inChannelOffset + curW * param.c + curC:
curC * inChannelOffset + curH * param.w + curW;
float4 tmp = reinterpret_cast<const float4 *>(&input[inOffset + inOffsetTmp])[0];
smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + 0] = tmp.x;
smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + BM+PAD] = tmp.y;
smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + 2*(BM+PAD)] = tmp.z;
smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + 3*(BM+PAD)] = tmp.w;
} else {
#pragma unroll
for (int i = 0; i < 4; ++i)
smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + i*(BM+PAD)] = 0.f;
}
} else {
#pragma unroll
for (int i = 0; i < 4; ++i){
const uint cur0 = fastdiv(innerColA * 4 + crs + BK + i,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK + i,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK + i,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
const uint curC = layout == 0 ? cur2 : cur0;
const uint curR = layout == 0 ? cur0 : cur1;
const uint curS = layout == 0 ? cur1 : cur2;
loadFilter<T, BN, rowStrideA, layout, vec_load, ksplit, PAD>
(kernel, &smemweight[write_flag * (BN+PAD) * BK], by, innerRowA, innerColA, weightKOffset,
crs+BK, end_k, param);
loadInput<BM, rowStrideA, layout, vec_load, ksplit, PAD>
(input, &smeminput[write_flag * (BM+PAD) * BK], bx, innerRowA, innerColA,
crs + BK, end_k, PQ, CHW, inChannelOffset, param);
const int curH = posh_ori + curR * param.d_h; // input h
const int curW = posw_ori + curS * param.d_w; // input w
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK + i < end_k){
int inOffsetTmp = layout == 0 ?
curH * inChannelOffset + curW * param.c + curC:
curC * inChannelOffset + curH * param.w + curW;
smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + i*(BM+PAD)] = input[inOffset + inOffsetTmp];
} else {
smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + i*(BM+PAD)] = 0.f;
}
}
}
}
__syncthreads();
write_flag ^= 1;

View File

@ -359,6 +359,139 @@ __device__ __forceinline__ uint32_t cvta_to_shared_u32(const void *pointer) {
return address;
}
template<typename T, const int BN, const int rowStrideA, const int layout,
const bool vec_load, const int ksplit, const int PAD>
__device__ __forceinline__ void loadFilter(const T * __restrict__ kernel,
T * __restrict__ smemweight,
const unsigned int by,
const unsigned int innerRowA,
const unsigned int innerColA,
const unsigned int weightKOffset,
const unsigned int start_k,
const unsigned int end_k,
const param_t param){
const unsigned int weight_sts_addr = innerRowA + innerColA * (BN+PAD) * 4;
const unsigned int kidx = start_k + innerColA * 4;
#pragma unroll
for (int offset = 0; offset + rowStrideA <= BN; offset += rowStrideA) {
const unsigned int nidx = by * BN + innerRowA + offset;
if (vec_load) {
if (nidx < param.k && kidx < end_k) {
if constexpr (std::is_same_v<T, float>){
float4 tmp = reinterpret_cast<const float4 *>(&kernel[nidx * weightKOffset + kidx])[0];
smemweight[weight_sts_addr + offset + 0] = tmp.x;
smemweight[weight_sts_addr + offset + (BN+PAD)] = tmp.y;
smemweight[weight_sts_addr + offset + 2*(BN+PAD)] = tmp.z;
smemweight[weight_sts_addr + offset + 3*(BN+PAD)] = tmp.w;
} else { // read 4 halves
float2 tmp = reinterpret_cast<const float2 *>(&kernel[nidx * weightKOffset + kidx])[0];
const half *val = reinterpret_cast<const half *>(&tmp);
smemweight[weight_sts_addr + offset + 0] = val[0];
smemweight[weight_sts_addr + offset + (BN+PAD)] = val[1];
smemweight[weight_sts_addr + offset + 2*(BN+PAD)] = val[2];
smemweight[weight_sts_addr + offset + 3*(BN+PAD)] = val[3];
}
} else {
#pragma unroll
for (int i = 0; i < 4; ++i) {
smemweight[weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f;
}
}
} else {
#pragma unroll
for (int i = 0; i < 4; ++i) {
if (nidx < param.k && kidx + i < end_k) {
smemweight[weight_sts_addr + offset + i*(BN+PAD)] = kernel[nidx * weightKOffset + kidx + i];
} else {
smemweight[weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f;
}
}
}
}
}
template<const int BM, const int rowStrideA, const int layout,
const bool vec_load, const int ksplit, const int PAD>
__device__ __forceinline__ void loadInput(const float * __restrict__ input,
float * __restrict__ smeminput,
const unsigned int bx,
const unsigned int innerRowA,
const unsigned int innerColA,
const unsigned int start_k,
const unsigned int end_k,
const unsigned int PQ,
const unsigned int CHW,
const unsigned int inChannelOffset,
const param_t param) {
const unsigned int input_sts_addr = innerRowA + innerColA * (BM+PAD) * 4;
const unsigned int kidx = start_k + innerColA * 4;
#pragma unroll
for (unsigned int offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) {
const unsigned int midx = bx * BM + innerRowA + offset;
int n = (ksplit > 0) ? midx / PQ : blockIdx.z;
const unsigned int npq_res = midx % PQ;
const int posh_ori = fastdiv((ksplit > 0) ? npq_res: midx, param.OW_fastdiv) * param.u - param.p;
const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: midx, param.OW_fastdiv) * param.v - param.q;
const unsigned int inOffset = n * CHW;
if (vec_load) {
const unsigned int cur0 = fastdiv(kidx,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
const unsigned int cur1 = fastdiv(fastmodulo(kidx,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
const unsigned int cur2 = fastmodulo(fastmodulo(kidx,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
const unsigned int curC = layout == 0 ? cur2 : cur0;
const unsigned int curR = layout == 0 ? cur0 : cur1;
const unsigned int curS = layout == 0 ? cur1 : cur2;
const int curH = posh_ori + curR * param.d_h; // input h
const int curW = posw_ori + curS * param.d_w; // input w
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && kidx < end_k) {
int inOffsetTmp = layout == 0 ?
curH * inChannelOffset + curW * param.c + curC:
curC * inChannelOffset + curH * param.w + curW;
float4 tmp = reinterpret_cast<const float4 *>(&input[inOffset + inOffsetTmp])[0];
smeminput[input_sts_addr + offset + 0] = tmp.x;
smeminput[input_sts_addr + offset + BM+PAD] = tmp.y;
smeminput[input_sts_addr + offset + 2*(BM+PAD)] = tmp.z;
smeminput[input_sts_addr + offset + 3*(BM+PAD)] = tmp.w;
} else {
#pragma unroll
for (int i = 0; i < 4; ++i)
smeminput[input_sts_addr + offset + i*(BM+PAD)] = 0.f;
}
} else {
#pragma unroll
for (int i = 0; i < 4; ++i) {
const unsigned int cur0 = fastdiv(kidx + i,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
const unsigned int cur1 = fastdiv(fastmodulo(kidx + i,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
const unsigned int cur2 = fastmodulo(fastmodulo(kidx + i,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
const unsigned int curC = layout == 0 ? cur2 : cur0;
const unsigned int curR = layout == 0 ? cur0 : cur1;
const unsigned int curS = layout == 0 ? cur1 : cur2;
const int curH = posh_ori + curR * param.d_h; // input h
const int curW = posw_ori + curS * param.d_w; // input w
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && kidx + i < end_k) {
int inOffsetTmp = layout == 0 ?
curH * inChannelOffset + curW * param.c + curC:
curC * inChannelOffset + curH * param.w + curW;
smeminput[input_sts_addr + offset + i*(BM+PAD)] = input[inOffset + inOffsetTmp];
} else {
smeminput[input_sts_addr + offset + i*(BM+PAD)] = 0.f;
}
}
}
}
}
#define CUDA_CONV2D_IMPLICT_BLOCK_SIZE 256
void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst);