get rid of a convert unary kernel call and fuse the type cast into the conv epilogue

bssrdf 2025-11-10 12:39:50 -05:00
parent 1fdcb05dc8
commit a660d4d45d
2 changed files with 18 additions and 20 deletions
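
What the change does, in sketch form: the kernel's output pointer becomes a template parameter, so the epilogue can cast the fp32 accumulator straight to the requested output type instead of always storing half and then running a separate convert kernel over the whole tensor. The snippet below is a minimal standalone illustration of that pattern, not the actual ggml code; cast_out is a hypothetical stand-in for ggml_cuda_cast<T>.

#include <cuda_fp16.h>

// Hypothetical stand-in for ggml_cuda_cast<T> (illustration only).
template <typename T> __device__ __forceinline__ T cast_out(float v);
template <> __device__ __forceinline__ float cast_out<float>(float v) { return v; }
template <> __device__ __forceinline__ half  cast_out<half>(float v)  { return __float2half(v); }

// Before: the epilogue always stored half, and a second "convert unary"
// kernel re-read the whole buffer to produce fp32. After: the store itself
// casts, so one kernel launch and one pass over the output suffice.
template <typename T>
static __global__ void store_epilogue(const float * __restrict__ acc, T * __restrict__ out, int n) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n) {
        out[i] = cast_out<T>(acc[i]); // fused type cast in the epilogue
    }
}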

File 1 of 2:

@@ -582,11 +582,11 @@ __device__ __forceinline__ void ldmatrix_b(
 #endif
 }
-template<const int BM, const int BN, const int BK, const int WM, const int WN,
+template<typename T, const int BM, const int BN, const int BK, const int WM, const int WN,
 const int WK, const int ksplit, const int NUM_THREADS>
 static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
 const half * __restrict__ kernel,
-half * __restrict__ output,
+T * __restrict__ output,
 const param_t param) {
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
@@ -763,10 +763,10 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
 const uint outOffset = z * NKPQ +
 n * KPQ +
 row * PQ + col;
-output[outOffset] = res_[0];
+output[outOffset] = ggml_cuda_cast<T>(res_[0]);
 } else {
 const uint outOffset = n * KPQ + row * PQ + col;
-output[outOffset] = res_[0];
+output[outOffset] = ggml_cuda_cast<T>(res_[0]);
 }
 }
 if (n < param.n && row+1 < param.k && col < PQ) {
@@ -774,10 +774,10 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
 const uint outOffset = z * NKPQ +
 n * KPQ +
 (row+1) * PQ + col;
-output[outOffset] = res_[1];
+output[outOffset] = ggml_cuda_cast<T>(res_[1]);
 } else {
 const uint outOffset = n * KPQ + (row+1) * PQ + col;
-output[outOffset] = res_[1];
+output[outOffset] = ggml_cuda_cast<T>(res_[1]);
 }
 }
 }
@@ -848,12 +848,12 @@ static void launch_conv2d_implicit_split_kernel(ggml_backend_cuda_context & ctx,
 int id = ggml_cuda_get_device();
 ggml_cuda_pool_alloc<half> Y_H(ctx.pool(id), ksplit * P.k * P.Oh * P.Ow * P.n);
-cudaFuncSetAttribute(conv2d_implicit_kernel<BM, BN, BK, WM, WN, WK, ksplit, NUM_THREADS>,
+cudaFuncSetAttribute(conv2d_implicit_kernel<half, BM, BN, BK, WM, WN, WK, ksplit, NUM_THREADS>,
 cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75
 dim3 gridDim(BlocksN, BlocksM, ksplit);
 dim3 blockDim(ThreadsN, ThreadsM);
-conv2d_implicit_kernel<BM, BN, BK,
+conv2d_implicit_kernel<half, BM, BN, BK,
 WM, WN, WK, ksplit, NUM_THREADS><<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_H, Y_H.get(), P);
 const unsigned int nrows = P.n * P.k * P.Oh * P.Ow;
@@ -866,7 +866,7 @@ static void launch_conv2d_implicit_split_kernel(ggml_backend_cuda_context & ctx,
 static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const float * X_D, const half * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) {
 // if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0 && (P.r > 1 || P.s > 1)) {
-if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0 && (P.r > 1 || P.s > 1)) {
+if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0) {
 int id = ggml_cuda_get_device();
@@ -883,8 +883,10 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
 ne = P.c * P.r * P.s * P.k;
 ne01 = P.r * P.s;
-ggml_cuda_pool_alloc<half> kernel_f16(ctx.pool(id), ne);
+// ggml_cuda_pool_alloc<half> kernel_f16(ctx.pool(id), ne);
+ggml_cuda_pool_alloc<half> kernel_f16(ctx.pool(id));
 if (ne01 > 1){
+kernel_f16.alloc(ne);
 dim3 dimGrid1((ne00 + CUDA_NCHW_2_NHWC_BLOCK_C - 1) / CUDA_NCHW_2_NHWC_BLOCK_C,
 (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM,
 1) ;
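
The allocation change in this hunk is the usual lazy-scratch pattern: the pool allocation for the transposed weights is deferred into the ne01 > 1 branch, so the 1x1-kernel case (now reachable because the r/s guard above was dropped) does not pay for a scratch buffer it never writes. A simplified standalone sketch, with a hypothetical pool_alloc in place of ggml_cuda_pool_alloc:

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstddef>

// Hypothetical stand-in for ggml_cuda_pool_alloc (illustration only):
// default-construct now, allocate only if the buffer is actually needed.
struct pool_alloc {
    void * ptr = nullptr;
    void alloc(size_t bytes) { cudaMalloc(&ptr, bytes); } // a real pool would recycle
    ~pool_alloc() { if (ptr) { cudaFree(ptr); } }
};

static void prepare_weights(size_t ne, int ne01) {
    pool_alloc kernel_f16;                    // no device memory taken yet
    if (ne01 > 1) {                           // r*s > 1: NCHW->NHWC transpose needed
        kernel_f16.alloc(ne * sizeof(half));  // scratch only for this path
        // ... launch the transpose kernel into kernel_f16.ptr ...
    }
    // ne01 == 1 (1x1 kernel): the source weights are used directly.
}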
@@ -973,7 +975,6 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
 if(candidate != -1){
 j = candidate;
-// printf(" choosing %d, %d \n", j, max_remaining_waves);
 if (j == 2) {
 launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 2,
 ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
@@ -1023,18 +1024,15 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
 return;
 }
 }
-ggml_cuda_pool_alloc<half> Y_H(ctx.pool(id), P.k * P.Oh * P.Ow * P.n);
-cudaFuncSetAttribute(conv2d_implicit_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 0, NumThreads>,
+cudaFuncSetAttribute(conv2d_implicit_kernel<float, BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 0, NumThreads>,
 cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75
 dim3 gridDim(BlocksN, BlocksM);
 dim3 blockDim(ThreadsN, ThreadsM);
-conv2d_implicit_kernel<BM_dim, BN_dim, BK_dim,
+conv2d_implicit_kernel<float, BM_dim, BN_dim, BK_dim,
 WM_dim, WN_dim, WK_dim, 0, NumThreads>
-<<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_H, Y_H.get(), P);
-const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow * P.n, st);
+<<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_H, Y_D, P);
 } else{
 conv2d_implicit_cuda<half, 1>(X_D, K_D, Y_D, P, st);
 }
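
Host-side effect of the hunk above, as a condensed before/after sketch (taken from the diff itself; the `...` stands for the unchanged template arguments): the temporary half buffer Y_H and the trailing to_fp32 pass disappear, saving one kernel launch, one full read and write of the output tensor, and P.k * P.Oh * P.Ow * P.n halves of pool memory.

// before: fp16 store into scratch, then a separate convert kernel
// ggml_cuda_pool_alloc<half> Y_H(ctx.pool(id), P.k * P.Oh * P.Ow * P.n);
// conv2d_implicit_kernel<...><<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_H, Y_H.get(), P);
// const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
// to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow * P.n, st);

// after: the float instantiation writes the fp32 destination directly
// conv2d_implicit_kernel<float, ...><<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_H, Y_D, P);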

File 2 of 2:

@@ -299,7 +299,7 @@ static std::vector<std::tuple<int, int, int, int, int, int>> configs = {
 // std::make_tuple(640,640,52,76,3,3),
 // std::make_tuple(640,640,104,152,3,3),
 // std::make_tuple(960,320,104,152,3,3),
-std::make_tuple(1280,1280,26,38,3,3),
+// std::make_tuple(1280,1280,26,38,3,3),
 // std::make_tuple(1920,640,32,32,3,3)
 // std::make_tuple(1280,1280,16,16,3,3),
 // std::make_tuple(320,640,32,32,3,3),
@@ -317,7 +317,7 @@ static std::vector<std::tuple<int, int, int, int, int, int>> configs = {
 // std::make_tuple(1920,1280,26,38,3,3),
 // std::make_tuple(2560,1280,26,38,3,3),
 // std::make_tuple(320,1280,26,38,3,3),
-// std::make_tuple(512,512,104,152,3,3),
+std::make_tuple(512,512,104,152,3,3),
 // std::make_tuple(512,512,208,304,3,3),
 // std::make_tuple(512,256,416,608,3,3),
 // std::make_tuple(256,128,832,1216,3,3),
@@ -714,7 +714,7 @@ int main(void)
 // for(int i = 0; i < ggml_nelements(wino_res); i++) {
 // for(int i = 0; i < 26*38; i++) {
-// // for(int i = 0; i < conv2d_data.size(); i++) {
+// for(int i = 0; i < conv2d_data.size(); i++) {
 // float diff = fabs(im2col_data[i] - conv2d_data[i]);
 // // if(diff > 0.5) {
 // printf("(%7.3f, %7.3f, %.2f, %d) \n",