added block variants; to be debugged

This commit is contained in:
bssrdf 2025-10-14 11:02:10 -04:00
parent 16b0f0ae3c
commit 2237722056
3 changed files with 131 additions and 56 deletions

View File

@ -1,8 +1,11 @@
#include "conv2d-implicit.cuh"
// #include <cuda_runtime.h>
#include "ggml.h"
#include "common.cuh"
#include "convert.cuh"
#include "conv2d-implicit.cuh"
static const int WARPSIZE = 32; // warpSize is not constexpr
typedef unsigned int uint;
static __global__ void reduce_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols, const int nrows) {
const int row = blockIdx.x;
@ -20,7 +23,8 @@ static __global__ void reduce_f32(const float * __restrict__ x, float * __restri
template<typename T, const int BM, const int BN, const int BK, const int WM, const int WN,
const int WNITER, const int TM, const int TN, const int NUM_THREADS,
const bool vec_load, const int ksplit, const int PAD=4>
// layout: 0, NHWC; 1, NCHW
const int layout, const bool vec_load, const int ksplit, const int PAD=4>
static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
const T * __restrict__ kernel,
float * __restrict__ output,
@ -76,7 +80,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
// int inOffset = (ksplit > 0): z * param.c * param.h * param.w ;
// int weiOffset = (by * BN + tx / 8 * 4) * param.c * param.r * param.s;
int inChannelOffset = param.c * param.w;
int inChannelOffset = layout == 0 ? param.c * param.w : param.h * param.w;
// int weightChannelOffset = param.r * param.s;
int weightKOffset = param.c * param.r * param.s;
@ -125,16 +129,16 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
if(vec_load){
// if (by * BN + innerRowA + offset < param.k && start_k + innerColA * 4 < param.c * param.r * param.s){
if (by * BN + innerRowA + offset < param.k && start_k + innerColA * 4 < end_k){
if constexpr (std::is_same_v(T, float)){
float4 tmp = reinterpret_cast<float4 *>(&kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4])[0];
if constexpr (std::is_same_v<T, float>){
float4 tmp = reinterpret_cast<const float4 *>(&kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4])[0];
smemweight[weight_sts_addr + offset + 0] = tmp.x;
smemweight[weight_sts_addr + offset + (BN+PAD)] = tmp.y;
smemweight[weight_sts_addr + offset + 2*(BN+PAD)] = tmp.z;
smemweight[weight_sts_addr + offset + 3*(BN+PAD)] = tmp.w;
}else{ // read 4 halves
// half val[4];
float2 tmp = reinterpret_cast<float2 *>(&kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4])[0];
half *val = reinterpret_cast<half *>(&tmp);
float2 tmp = reinterpret_cast<const float2 *>(&kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4])[0];
const half *val = reinterpret_cast<const half *>(&tmp);
// val[1] = reinterpret_cast<half2 *>(&tmp.y);
smemweight[weight_sts_addr + offset + 0] = val[0];
smemweight[weight_sts_addr + offset + (BN+PAD)] = val[1];
@ -177,15 +181,28 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q;
int inOffset = n * param.c * param.h * param.w ;
if(vec_load){
const uint curR = fastdiv(start_k + innerColA * 4, param.SC_fastdiv); // channel offset
const uint curS = fastdiv(fastmodulo(start_k + innerColA * 4, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
const uint curC = fastmodulo(fastmodulo(start_k + innerColA * 4, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
// const uint curR = fastdiv(start_k + innerColA * 4, param.SC_fastdiv); // channel offset
// const uint curS = fastdiv(fastmodulo(start_k + innerColA * 4, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
// const uint curC = fastmodulo(fastmodulo(start_k + innerColA * 4, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
const uint cur0 = fastdiv(start_k + innerColA * 4,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
const uint cur1 = fastdiv(fastmodulo(start_k + innerColA * 4,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
const uint cur2 = fastmodulo(fastmodulo(start_k + innerColA * 4,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
const uint curC = layout == 0 ? cur2 : cur0;
const uint curR = layout == 0 ? cur0 : cur1;
const uint curS = layout == 0 ? cur1 : cur2;
const int curH = posh_ori + curR; // input h
const int curW = posw_ori + curS; // input w
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 < end_k){
int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
float4 tmp = reinterpret_cast<float4 *>(&input[inOffset + inOffsetTmp])[0];
// int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
int inOffsetTmp = layout == 0 ?
curH * inChannelOffset + curW * param.c + curC:
curC * inChannelOffset + curH * param.w + curW;
float4 tmp = reinterpret_cast<const float4 *>(&input[inOffset + inOffsetTmp])[0];
smeminput[input_sts_addr + offset + 0] = tmp.x;
smeminput[input_sts_addr + offset + BM] = tmp.y;
smeminput[input_sts_addr + offset + 2*BM] = tmp.z;
@ -198,14 +215,27 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
} else {
#pragma unroll
for (int i = 0; i < 4; ++i){
const uint curR = fastdiv(start_k + innerColA * 4 + i, param.SC_fastdiv); // channel offset
const uint curS = fastdiv(fastmodulo(start_k + innerColA * 4 + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
const uint curC = fastmodulo(fastmodulo(start_k + innerColA * 4 + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
// const uint curR = fastdiv(start_k + innerColA * 4 + i, param.SC_fastdiv); // channel offset
// const uint curS = fastdiv(fastmodulo(start_k + innerColA * 4 + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
// const uint curC = fastmodulo(fastmodulo(start_k + innerColA * 4 + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
const uint cur0 = fastdiv(start_k + innerColA * 4 + i,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
const uint cur1 = fastdiv(fastmodulo(start_k + innerColA * 4 + i,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
const uint cur2 = fastmodulo(fastmodulo(start_k + innerColA * 4 + i,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
const uint curC = layout == 0 ? cur2 : cur0;
const uint curR = layout == 0 ? cur0 : cur1;
const uint curS = layout == 0 ? cur1 : cur2;
const int curH = posh_ori + curR; // input h
const int curW = posw_ori + curS; // input w
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 + i < end_k){
int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
// int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
int inOffsetTmp = layout == 0 ?
curH * inChannelOffset + curW * param.c + curC:
curC * inChannelOffset + curH * param.w + curW;
smeminput[input_sts_addr + offset + i*BM] = input[inOffset + inOffsetTmp];
} else {
smeminput[input_sts_addr + offset + i*BM] = 0.f;
@ -398,15 +428,15 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
for (uint offset = 0; offset + rowStrideA <= BN; offset += rowStrideA) {
if(vec_load){
if (by * BN + innerRowA + offset < param.k && innerColA * 4 + crs + BK < end_k){
if constexpr (std::is_same_v(T, float)){
float4 tmp = reinterpret_cast<float4 *>(&kernel[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK])[0];
if constexpr (std::is_same_v<T, float>){
float4 tmp = reinterpret_cast<const float4 *>(&kernel[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK])[0];
smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 0] = tmp.x;
smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + (BN+PAD)] = tmp.y;
smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 2*(BN+PAD)] = tmp.z;
smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 3*(BN+PAD)] = tmp.w;
} else {
float2 tmp = reinterpret_cast<float2 *>(&kernel[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK])[0];
half *val = reinterpret_cast<half *>(&tmp);
float2 tmp = reinterpret_cast<const float2 *>(&kernel[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK])[0];
const half *val = reinterpret_cast<const half *>(&tmp);
smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 0] = val[0];
smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + (BN+PAD)] = val[1];
smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 2*(BN+PAD)] = val[2];
@ -437,15 +467,29 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q;
int inOffset = n * param.c * param.h * param.w ;
if(vec_load){
const uint curR = fastdiv(innerColA * 4 + crs + BK, param.SC_fastdiv); // channel offset
const uint curS = fastdiv(fastmodulo(innerColA * 4 + crs + BK, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
const uint curC = fastmodulo(fastmodulo(innerColA * 4 + crs + BK, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
// const uint curR = fastdiv(innerColA * 4 + crs + BK, param.SC_fastdiv); // channel offset
// const uint curS = fastdiv(fastmodulo(innerColA * 4 + crs + BK, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
// const uint curC = fastmodulo(fastmodulo(innerColA * 4 + crs + BK, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
const uint cur0 = fastdiv(innerColA * 4 + crs + BK,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
const uint curC = layout == 0 ? cur2 : cur0;
const uint curR = layout == 0 ? cur0 : cur1;
const uint curS = layout == 0 ? cur1 : cur2;
const int curH = posh_ori + curR; // input h
const int curW = posw_ori + curS; // input w
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK < end_k){
int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
float4 tmp = reinterpret_cast<float4 *>(&input[inOffset + inOffsetTmp])[0];
// int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
int inOffsetTmp = layout == 0 ?
curH * inChannelOffset + curW * param.c + curC:
curC * inChannelOffset + curH * param.w + curW;
float4 tmp = reinterpret_cast<const float4 *>(&input[inOffset + inOffsetTmp])[0];
smeminput[write_flag * BM * BK + input_sts_addr + offset + 0] = tmp.x;
smeminput[write_flag * BM * BK + input_sts_addr + offset + BM] = tmp.y;
smeminput[write_flag * BM * BK + input_sts_addr + offset + 2*BM] = tmp.z;
@ -458,14 +502,28 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
} else {
#pragma unroll
for (int i = 0; i < 4; ++i){
const uint curR = fastdiv(innerColA * 4 + crs + BK + i, param.SC_fastdiv); // channel offset
const uint curS = fastdiv(fastmodulo(innerColA * 4 + crs + BK + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
const uint curC = fastmodulo(fastmodulo(innerColA * 4 + crs + BK + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
// const uint curR = fastdiv(innerColA * 4 + crs + BK + i, param.SC_fastdiv); // channel offset
// const uint curS = fastdiv(fastmodulo(innerColA * 4 + crs + BK + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
// const uint curC = fastmodulo(fastmodulo(innerColA * 4 + crs + BK + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
const uint cur0 = fastdiv(innerColA * 4 + crs + BK + i,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK + i,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK + i,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
const uint curC = layout == 0 ? cur2 : cur0;
const uint curR = layout == 0 ? cur0 : cur1;
const uint curS = layout == 0 ? cur1 : cur2;
const int curH = posh_ori + curR; // input h
const int curW = posw_ori + curS; // input w
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK + i < end_k){
int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
// int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
int inOffsetTmp = layout == 0 ?
curH * inChannelOffset + curW * param.c + curC:
curC * inChannelOffset + curH * param.w + curW;
smeminput[write_flag * BM * BK + input_sts_addr + offset + i*BM] = input[inOffset + inOffsetTmp];
} else {
smeminput[write_flag * BM * BK + input_sts_addr + offset + i*BM] = 0.f;
@ -684,26 +742,37 @@ constexpr static int conv_shapes[][NUM_VARIANTS] = {
{ 256, 256, 128} // NUM_THREADS
};
template <typename T>
template <typename T, unsigned int CONV_SHAPE>
static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const param_t P, cudaStream_t st) {
int blockx = ((P.Oh * P.Ow + 127) / 128); // blockx number
int blocky = (P.k + 127) / 128; // blocky number
const uint BM = conv_shapes[0][CONV_SHAPE];
const uint BN = conv_shapes[1][CONV_SHAPE];
const uint BK = conv_shapes[2][CONV_SHAPE];
const uint WM = conv_shapes[3][CONV_SHAPE];
const uint WN = conv_shapes[4][CONV_SHAPE];
const uint WNITER = conv_shapes[5][CONV_SHAPE];
const uint TM = conv_shapes[6][CONV_SHAPE];
const uint TN = conv_shapes[7][CONV_SHAPE];
const uint NUM_THREADS = conv_shapes[8][CONV_SHAPE];
int blockx = ((P.Oh * P.Ow + BM - 1) / BM); // blockx number
int blocky = (P.k + BN-1) / BN; // blocky number
int blockz = P.n; // blockz number
int threadx = CUDA_CONV2D_IMPLICT_BLOCK_SIZE; // threadx number per block
// int threadx = NUM; // threadx number per block
int thready = 1; // thready number per block
int threadz = 1; // threadz number per block
dim3 thblock(threadx, thready, threadz);
dim3 thblock(NUM_THREADS, thready, threadz);
dim3 grid(blockx, blocky, blockz);
int smem_size = 24 * 1024;
conv2d_implicit_kernel<T><<<grid, thblock, smem_size, st>>>(X_D, K_D, Y_D, P);
// int smem_size = 24 * 1024;
conv2d_implicit_kernel<T, BM, BN, BK, WM, WN,
WNITER, TM, TN, NUM_THREADS, 1, false, 0><<<grid, thblock, 0, st>>>(X_D, K_D, Y_D, P);
}
static void conv2d_implicit_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const param_t P, cudaStream_t st) {
conv2d_implicit_cuda<half>(X_D, K_D, Y_D, P, st);
conv2d_implicit_cuda<half, 0>(X_D, K_D, Y_D, P, st);
}
static void conv2d_implicit_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const param_t P, cudaStream_t st) {
conv2d_implicit_cuda<float>(X_D, K_D, Y_D, P, st);
conv2d_implicit_cuda<float, 0>(X_D, K_D, Y_D, P, st);
}
void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@ -745,9 +814,12 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor *
const int64_t total = B * OC * OH * OW;
param_t params = { B, IC, IH, IW, OC, KH, KW, ST_Y, ST_X, PD_Y, PD_X, DL_Y, DL_X, OH, OW };
params.SC_fastdiv = init_fastdiv_values(KW*KH);
params.SC_fastdiv = init_fastdiv_values(KW*IC);
params.OW_fastdiv = init_fastdiv_values(OW);
params.C_fastdiv = init_fastdiv_values(IC);
params.RS_fastdiv = init_fastdiv_values(KW*KH);
params.S_fastdiv = init_fastdiv_values(KW);
params.nchw = false;
if (kernel->type == GGML_TYPE_F16) {
conv2d_implicit_cuda_f16(X_D, (half *) K_D, Y_D, params, st);

View File

@ -17,9 +17,12 @@ typedef struct{
unsigned int d_w; //dilation width
unsigned int Oh; //output height
unsigned int Ow; //output width
bool nchw;
uint3 SC_fastdiv;
uint3 OW_fastdiv;
uint3 C_fastdiv;
uint3 RS_fastdiv;
uint3 S_fastdiv;
} param_t;

View File

@ -339,21 +339,21 @@ int main(void)
{
ggml_time_init();
std::vector<std::tuple<int, int, int, int>> configs = {
std::make_tuple(64,64,48,64),
std::make_tuple(320,320,104,152),
std::make_tuple(640,640,52,76),
std::make_tuple(640,640,104,152),
std::make_tuple(960,320,104,152),
std::make_tuple(1280,1280,26,38),
std::make_tuple(1280,640,52,76),
std::make_tuple(1920,1280,26,38),
std::make_tuple(2560,1280,26,38),
std::make_tuple(512,512,104,152),
std::make_tuple(512,512,208,304),
// std::make_tuple(64,64,48,64),
// std::make_tuple(320,320,104,152),
// std::make_tuple(640,640,52,76),
// std::make_tuple(640,640,104,152),
// std::make_tuple(960,320,104,152),
// std::make_tuple(1280,1280,26,38),
// std::make_tuple(1280,640,52,76),
// std::make_tuple(1920,1280,26,38),
// std::make_tuple(2560,1280,26,38),
// std::make_tuple(512,512,104,152),
// std::make_tuple(512,512,208,304),
std::make_tuple(512,256,416,608),
std::make_tuple(256,128,832,1216),
std::make_tuple(256,256,832,1216),
std::make_tuple(320,256,1024,1920)
// std::make_tuple(256,128,832,1216),
// std::make_tuple(256,256,832,1216),
// std::make_tuple(320,256,1024,1920)
};
int k = 0;