fuse cast to float into conv epilogue; improve swizzling for output

This commit is contained in:
bssrdf 2025-11-10 13:13:36 -05:00
parent d2d814c156
commit a428feecdd
2 changed files with 32 additions and 23 deletions

View File

@ -681,11 +681,11 @@ __device__ __forceinline__ void ldmatrix_b(
#endif
}
template<const int BM, const int BN, const int BK, const int WM, const int WN,
template<typename T, const int BM, const int BN, const int BK, const int WM, const int WN,
const int WK, const int NUM_THREADS>
static __global__ void conv3d_implicit_kernel(const half * __restrict__ input,
const half * __restrict__ kernel,
half * __restrict__ output,
T * __restrict__ output,
const param_t param) {
#if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
@ -828,27 +828,36 @@ static __global__ void conv3d_implicit_kernel(const half * __restrict__ input,
uint32_t (&reg_)[2] = reinterpret_cast<uint32_t(&)[2]>(acc_register_[mma_m][mma_n]);
uint idx = output_sts_addr +
mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N;
idx = idx ^ ((idx & 0b110000000000) >> 9);
idx = idx ^ ((idx & 0b1110000000) >> 4);
uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[idx]);
dst_ptr[0] = reg_[0];
dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[idx + 8 * BN / 2]);
idx = (idx + 8 * BN / 2 ) ^ 0b010;
dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[idx]);
dst_ptr[0] = reg_[1];
}
}
__syncthreads();
#pragma unroll
for (int subk = 0; subk < WN / 2; ++subk){
for (int subk = 0; subk < WN / 4; ++subk){
const uint row = m_idx + subk*2 + i * WN / 2;
uint idx = output_lds_addr + subk*2; // + j*32*BN/2;
idx = idx ^ ((idx & 0b110000000000) >> 9);
idx = idx ^ ((idx & 0b1110000000) >> 4);
for (int j = 0; j < 4; ++j){
const uint row = m_idx + subk + i * WN / 2;
const uint gemm_i = n_idx + j*32;
const int n = fastdiv(gemm_i, param.PQZ_fastdiv);
const int col = fastmodulo(gemm_i, param.PQZ_fastdiv);
uint32_t dst_ptr = *(reinterpret_cast<uint32_t*>(&smemoutput[idx+j*32*BN/2]));
half (&res_)[2] = reinterpret_cast<half(&)[2]>(dst_ptr);
if(n < param.n && row < param.k && col < PQZ){
const uint outOffset = (n * param.k + row) * PQZ + col;
uint idx = output_lds_addr + subk + j*32*BN/2;
idx = idx ^ ((idx & 0b1110000000) >> 4);
output[outOffset] = smemoutput[idx];
output[outOffset] = ggml_cuda_cast<T>(res_[0]);
}
if(n < param.n && row+1 < param.k && col < PQZ){
const uint outOffset = (n * param.k + row + 1) * PQZ + col;
output[outOffset] = ggml_cuda_cast<T>(res_[1]);
}
}
}
@ -924,15 +933,17 @@ static void conv3d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
ne = P.c * P.r * P.s * P.t * P.k;
ne01 = P.r * P.s * P.t;
ggml_cuda_pool_alloc<half> kernel_f16(ctx.pool(id), ne);
dim3 dimGrid1((ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
(ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
(ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ;
NCHW2NHWC<half, half><<<dimGrid1, dimBlock, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
ggml_cuda_pool_alloc<half> kernel_f16(ctx.pool(id));
if(ne01 > 1){
kernel_f16.alloc(ne);
dim3 dimGrid1((ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
(ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
(ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ;
NCHW2NHWC<half, half><<<dimGrid1, dimBlock, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
}
const half *X_H = input_f16.get();
const half *K_H = kernel_f16.get();
ggml_cuda_pool_alloc<half> Y_H(ctx.pool(id), P.k * P.Od *P.Oh * P.Ow * P.n);
const half *K_H = ne01 == 1 ? K_D : kernel_f16.get();
constexpr unsigned int BM_dim = 256;
constexpr unsigned int BN_dim = 256;
@ -952,16 +963,14 @@ static void conv3d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
constexpr unsigned int NumThreads = ThreadsM * ThreadsN;
const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half);
cudaFuncSetAttribute(conv3d_implicit_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, NumThreads>,
cudaFuncSetAttribute(conv3d_implicit_kernel<float, BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, NumThreads>,
cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75
dim3 gridDim(BlocksN, BlocksM);
dim3 blockDim(ThreadsN, ThreadsM);
conv3d_implicit_kernel<BM_dim, BN_dim, BK_dim,
conv3d_implicit_kernel<float, BM_dim, BN_dim, BK_dim,
WM_dim, WN_dim, WK_dim, NumThreads>
<<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_H, Y_H.get(), P);
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow *P.Od * P.n, st);
<<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_H, Y_D, P);
} else{
conv3d_implicit_cuda<half, 1>(X_D, K_D, Y_D, P, st);
}

View File

@ -350,9 +350,9 @@ int main(void)
{
ggml_time_init();
std::vector<std::tuple<int, int, int, int, int, int, int, int>> configs = {
std::make_tuple(1,2,16,32,4,3,3,3),
// std::make_tuple(1,2,16,32,4,3,3,3),
// std::make_tuple(320,1280,26,38,8,3,3,3),
// std::make_tuple(1280,1280,26,38,8,3,3,3),
std::make_tuple(1280,1280,26,38,8,3,3,3),
// std::make_tuple(320,1280,52,76,8,3,3,3),
// std::make_tuple(1280,1280,52,76,8,3,3,3),
// std::make_tuple(320,1280,104,152,8,3,3,3),
@ -380,7 +380,7 @@ int main(void)
// fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f);
int iterations = 20;
int iterations = 0;
double run_time0;
std::vector<float> im2col_data = compute_graph(model, allocr, build_graph_0, iterations,