From f0a480cc221aedf106f015410341bb22a740c370 Mon Sep 17 00:00:00 2001
From: bssrdf
Date: Tue, 21 Oct 2025 15:43:35 -0400
Subject: [PATCH] WIP

---
 ggml/src/ggml-cuda/conv2d-implicit.cu | 176 ++++++++++++++++++++++++++
 1 file changed, 176 insertions(+)

diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu
index 93ede3efc8..174b9b46ba 100644
--- a/ggml/src/ggml-cuda/conv2d-implicit.cu
+++ b/ggml/src/ggml-cuda/conv2d-implicit.cu
@@ -730,6 +730,182 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
     }
 }
 
+template <unsigned int BM_dim, unsigned int BN_dim, unsigned int BK_dim,
+          unsigned int WM_dim, unsigned int WN_dim, unsigned int NUM_THREADS>
+static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
+                                              const half * __restrict__ kernel,
+                                              float * __restrict__ output,
+                                              const param_t param) {
+    constexpr unsigned int MMA_M = 16;
+    constexpr unsigned int MMA_N = 8;
+
+    const unsigned int K = param.c * param.r * param.s;
+
+    // for convenience/readability in index calculations
+    const unsigned int A_stride = K;
+    const unsigned int B_stride = N;
+    const unsigned int CD_stride = N;
+
+    // calculate how many bits of shared memory indices are going to be swizzled, and create masks
+    constexpr unsigned int SWIZZLE_BITS_B = int_log2(BN / 8);
+
+    // loop bounds, constexpr where possible allows for loop unrolling
+    constexpr unsigned int mma_tiles_per_warp_k = 4;
+    constexpr unsigned int mma_tiles_per_warp_m = WM / MMA_M;
+    constexpr unsigned int mma_tiles_per_warp_n = WN / MMA_N;
+    const unsigned int num_block_tiles_k = K / BK;
+
+    // calculate block/warp indices
+    const unsigned int block_m = blockIdx.y;
+    const unsigned int block_n = blockIdx.x;
+    const unsigned int warp_m = threadIdx.y;
+    const unsigned int warp_n = threadIdx.x / 32;
+
+    // double buffering
+    extern __shared__ half shmem[];
+    half* A_block_smem = shmem;
+    half* B_block_smem = &shmem[BM * BK];
+    constexpr int BUFFER_SIZE = BM * BK + BK * BN;
+
+    // declare register storage
+    // ptx instructions expect uint32_t registers, where each uint32_t is 2 halfs packed together
+    uint32_t acc_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][2];
+    uint32_t A_register[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2];
+    uint32_t B_register[mma_tiles_per_warp_k][mma_tiles_per_warp_n];
+
+    // convenience cast to half for register storage
+    half (&acc_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_n][4] = reinterpret_cast<half (&)[mma_tiles_per_warp_m][mma_tiles_per_warp_n][4]>(acc_register);
+    half (&A_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] = reinterpret_cast<half (&)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4]>(A_register);
+    half (&B_register_) [mma_tiles_per_warp_k][mma_tiles_per_warp_n][2] = reinterpret_cast<half (&)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2]>(B_register);
+
+    // accumulators start at 0
+    for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++)
+    {
+        for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++)
+        {
+            acc_register_[mma_m][mma_n][0] = 0;
+            acc_register_[mma_m][mma_n][1] = 0;
+            acc_register_[mma_m][mma_n][2] = 0;
+            acc_register_[mma_m][mma_n][3] = 0;
+        }
+    }
+
+    // these register arrays are used to cache values pre-fetched from global memory during the inner loop of the kernel
+    // the code is nicer if we hard code it for these tile dimensions and number of threads
+    // since we are performing this copy with float4 pointers, for these tile dimensions it works out to 4 float4s for A and 4 float4s for B
+    static_assert(BM_dim == 256);
+    static_assert(BN_dim == 256);
+    static_assert(BK_dim == 32);
+    static_assert(NUM_THREADS == 256);
+    float4 A_gmem_cache_reg[4];
+    float4 B_gmem_cache_reg[4];
+
+    // prefetch the first block tile of A,B into shared memory
+    half* A_block_gmem = A + (block_m * BM_dim * A_stride);
+    half* B_block_gmem = B + (block_n * BN_dim);
+    tileMemcpySwizzleA(A_block_gmem, A_block_smem, K);
+    tileMemcpySwizzle(B_block_gmem, B_block_smem, N);
+
+    // construct const pointers to warp tiles for use inside the inner loop
+
+
+    int offset_direction = 1;
+
+    for (unsigned int block_k = 1; block_k <= num_block_tiles_k; block_k++)
+    {
+        __syncthreads();
+
+        if (block_k != num_block_tiles_k)
+        {
+            half* A_block_gmem = A + (block_m * BM_dim * A_stride) + (block_k * BK_dim);
+            half* B_block_gmem = B + (block_k * BK_dim * B_stride) + (block_n * BN_dim);
+            tileMemcpyLoad(A_block_gmem, A_gmem_cache_reg, K);
+            tileMemcpyLoad(B_block_gmem, B_gmem_cache_reg, N);
+        }
+        half* A_warp_tile = A_block_smem + (warp_m * WM_dim * BK_dim);
+        half* B_warp_tile = B_block_smem + (warp_n * WN_dim);
+
+        ldmatrix_a(A_warp_tile, A_register_);
+        ldmatrix_b(B_warp_tile, B_register_);
+
+        // outer product between mma tiles
+        #pragma unroll
+        for (unsigned int mma_k = 0; mma_k < mma_tiles_per_warp_k; mma_k++)
+        {
+            #pragma unroll
+            for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++)
+            {
+                #pragma unroll
+                for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++)
+                {
+                    asm volatile (
+                        "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
+                        "{%0, %1}, "
+                        "{%2, %3}, "
+                        "{%4}, "
+                        "{%5, %6};"
+                        : "=r"(acc_register[mma_m][mma_n][0]), "=r"(acc_register[mma_m][mma_n][1])
+                        : "r"(A_register[mma_m][mma_k][0]), "r"(A_register[mma_m][mma_k][1]),
+                          "r"(B_register[mma_k][mma_n]),
+                          "r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1])
+                    );
+                }
+            }
+        }
+
+
+        if (block_k != num_block_tiles_k)
+        {
+            // switch smem buffers each iteration
+            A_block_smem = A_block_smem + BUFFER_SIZE * offset_direction;
+            B_block_smem = B_block_smem + BUFFER_SIZE * offset_direction;
+            offset_direction = -1 * offset_direction;
+
+            tileMemcpySwizzleStoreA(A_gmem_cache_reg, A_block_smem);
+            tileMemcpySwizzleStore(B_gmem_cache_reg, B_block_smem);
+        }
+    }
+
+    //////////////
+    // epilogue //
+    //////////////
+    half alpha_ = (half)alpha;
+    half beta_ = (half)beta;
+    half C_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][4];
+
+    // calculate pointers for this warp's C and D tiles
+    half* C_block_gmem = C + (block_m * BM_dim * CD_stride) + (block_n * BN_dim);
+    half* C_warp_gmem = C_block_gmem + (warp_m * WM_dim * CD_stride) + (warp_n * WN_dim);
+    half* D_block_gmem = D + (block_m * BM_dim * CD_stride) + (block_n * BN_dim);
+    half* D_warp_gmem = D_block_gmem + (warp_m * WM_dim * CD_stride) + (warp_n * WN_dim);
+
+    for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++)
+    {
+        for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++)
+        {
+            half* C_mma_tile = C_warp_gmem + (mma_m * MMA_M_dim * CD_stride) + (mma_n * MMA_N_dim);
+            ldmatrix_m16n8_gmem(C_mma_tile, C_register[mma_m][mma_n], N * sizeof(half));
+
+            // acc = alpha * acc + beta * C
+            acc_register_[mma_m][mma_n][0] = acc_register_[mma_m][mma_n][0] * alpha_ + C_register[mma_m][mma_n][0] * beta_;
+            acc_register_[mma_m][mma_n][1] = acc_register_[mma_m][mma_n][1] * alpha_ + C_register[mma_m][mma_n][1] * beta_;
+            acc_register_[mma_m][mma_n][2] = acc_register_[mma_m][mma_n][2] * alpha_ + C_register[mma_m][mma_n][2] * beta_;
+            acc_register_[mma_m][mma_n][3] = acc_register_[mma_m][mma_n][3] * alpha_ + C_register[mma_m][mma_n][3] * beta_;
+        }
+    }
+
+    for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++)
+    {
+        for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++)
+        {
+            half* D_mma_tile = D_warp_gmem + (mma_m * MMA_M_dim * CD_stride) + (mma_n * MMA_N_dim);
+            stmatrix_m16n8(D_mma_tile, acc_register_[mma_m][mma_n], N * sizeof(half));
+        }
+    }
+
+}
+
+
 #define NUM_VARIANTS 6
 
 /*