From 80a996cfc0019e615af23348072c735988357ee1 Mon Sep 17 00:00:00 2001
From: bssrdf
Date: Fri, 24 Oct 2025 11:41:11 -0400
Subject: [PATCH] WIP: tensor core code compiled ok

---
 ggml/src/ggml-cuda/conv2d-implicit.cu  | 184 ++++++++++++++-----------
 ggml/src/ggml-cuda/conv2d-implicit.cuh |  33 +++--
 2 files changed, 126 insertions(+), 91 deletions(-)

diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu
index 06bb4c53f1..482270e2c7 100644
--- a/ggml/src/ggml-cuda/conv2d-implicit.cu
+++ b/ggml/src/ggml-cuda/conv2d-implicit.cu
@@ -730,7 +730,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
     }
 }
 
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+
 template <unsigned int mma_tiles_per_warp_m, unsigned int mma_tiles_per_warp_k, unsigned int smem_stride>
 __device__ __forceinline__ void ldmatrix_a(
@@ -738,6 +738,7 @@ __device__ __forceinline__ void ldmatrix_a(
     half (&reg)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4]
 )
 {
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
     static_assert(mma_tiles_per_warp_m == 8, "mma_tiles_per_warp_m must be 4");
     static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");
 
@@ -880,7 +881,11 @@ __device__ __forceinline__ void ldmatrix_a(
         : "=r"(reg_[6][3][0]), "=r"(reg_[6][3][1]), "=r"(reg_[7][3][0]), "=r"(reg_[7][3][1])
         : "r"(src_addr + 96 * smem_stride_)
     );
-
+#else
+    GGML_UNUSED(src);
+    GGML_UNUSED(reg);
+    NO_DEVICE_CODE;
+#endif
 }
 
 template <unsigned int mma_tiles_per_warp_k, unsigned int mma_tiles_per_warp_n, unsigned int smem_stride>
@@ -889,10 +894,11 @@ __device__ __forceinline__ void ldmatrix_b(
     half (&reg)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2]
 )
 {
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
     static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");
     static_assert(mma_tiles_per_warp_n == 8, "mma_tiles_per_warp_n must be 8");
 
-//    uint32_t (&reg_) [4][8] = reinterpret_cast<uint32_t (&)[4][8]>(reg);
+    uint32_t (&reg_) [4][8] = reinterpret_cast<uint32_t (&)[4][8]>(reg);
     // const unsigned int logical_offset = ((threadIdx.x % 8) * smem_stride) + (((threadIdx.x % 32) / 8) * 8);
     // unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b11100000000) >> 5);
     // uint32_t src_addr = cvta_to_shared_u32(src + swizzled_offset);
@@ -979,15 +985,20 @@ __device__ __forceinline__ void ldmatrix_b(
     //     : "r"(src_addr ^ 0b1000000)
         : "r"(src_addr + 32 * smem_stride_)
     );
-
+#else
+    GGML_UNUSED(src);
+    GGML_UNUSED(reg);
+    NO_DEVICE_CODE;
+#endif
 }
 
 template <unsigned int BM, unsigned int BN, unsigned int BK, unsigned int WM, unsigned int WN, unsigned int WK, unsigned int NUM_THREADS>
-static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
+static __global__ void conv2d_implicit_kernel_tc(const half * __restrict__ input,
                                               const half * __restrict__ kernel,
                                               half * __restrict__ output,
                                               const param_t param) {
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 
     constexpr unsigned int MMA_M = 16;
     constexpr unsigned int MMA_N = 8;
@@ -997,18 +1008,18 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
     const uint weightKOffset = param.c * param.r * param.s;
 
     // for convenience/readability in index calculations
-    const unsigned int A_stride = K;
-    const unsigned int B_stride = N;
-    const unsigned int CD_stride = N;
+//    const unsigned int A_stride = K;
+//    const unsigned int B_stride = N;
+//    const unsigned int CD_stride = N;
 
     // calculate how many bits of shared memory indices are going to be swizzled, and create masks
-    constexpr unsigned int SWIZZLE_BITS_B = int_log2(BN / 8);
+//    constexpr unsigned int SWIZZLE_BITS_B = int_log2(BN / 8);
 
     // loop bounds, constexpr where possible allows for loop unrolling
     constexpr unsigned int mma_tiles_per_warp_k = 4;
    constexpr unsigned int mma_tiles_per_warp_m = WM / MMA_M;
     constexpr unsigned int mma_tiles_per_warp_n = WN / MMA_N;
-    const unsigned int num_block_tiles_k = K / BK;
+    const unsigned int num_block_tiles_k = (K + (BK-1)) / BK;
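+    // NOTE: ceil-divide so that a partial trailing K-tile is still processed when K is not a
+    // multiple of BK; the tile-copy helpers zero-fill out-of-range elements, so the extra MMA
+    // iterations should simply accumulate zeros.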
 
     // calculate block/warp indices
     const unsigned int block_m = blockIdx.y;
@@ -1057,8 +1068,8 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
 
     // prefetch the first block tile of A,B into shared memory
     // half* A_block_gmem = input + (block_m * BM * A_stride);
-    half* A_block_gmem = input;
-    half* B_block_gmem = kernel + (block_n * weightKOffset);
+    const half* A_block_gmem = input;
+    const half* B_block_gmem = kernel + (block_n * weightKOffset);
     tileMemcpySwizzleA(A_block_gmem, A_block_smem, inChannelOffset, param);
     tileMemcpySwizzleB(B_block_gmem, B_block_smem, weightKOffset, param);
 
@@ -1074,10 +1085,10 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
         if (block_k != num_block_tiles_k)
         {
            // half* A_block_gmem = A + (block_m * BM * A_stride) + (block_k * BK);
-            half* A_block_gmem = input;
-            half* B_block_gmem = kernel + (block_n * weightKOffset);
-            tileMemcpyLoad(A_block_gmem, A_gmem_cache_reg, block_k * BK, inChannelOffset, param);
-            tileMemcpyLoad(B_block_gmem, B_gmem_cache_reg, block_k * BK, weightKOffset, param);
+            const half* A_block_gmem = input;
+            const half* B_block_gmem = kernel + (block_n * weightKOffset);
+            tileMemcpyLoadA(A_block_gmem, A_gmem_cache_reg, block_k * BK, inChannelOffset, param);
+            tileMemcpyLoadB(B_block_gmem, B_gmem_cache_reg, block_k * BK, weightKOffset, param);
         }
         half* A_warp_tile = A_block_smem + (warp_m * WM * BK);
         half* B_warp_tile = B_block_smem + (warp_n * WN * BK);
@@ -1124,14 +1135,14 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
     }
 
     // reuse smem
-    half *smemoutput = reinterpret_cast<half *>(shmem);
+    half *smemoutput = shmem;
     const uint lane_id = threadIdx.x % WARPSIZE;
     const uint mma_row = lane_id / 4;
     const uint mma_col = lane_id % 4;
-    const uint output_lds_addr = warp_id * WSUBM * WSUBN + lane_id;
+    const uint output_lds_addr = warp_m * WM * BN/2 + lane_id * BN/2 + warp_n * WN/2;
     const uint output_sts_addr = warp_m * WM * BN/2 + mma_row * BN/2 + warp_n * WN/2 + mma_col * 2;
-    const uint m_idx = by * BN + mma_tid_y * WN;
-    const uint n_idx = block_m * BM + warp_m * WM;
+    const uint m_idx = block_n * BN + warp_n * WN;
+    const uint n_idx = block_m * BM + warp_m * WM + lane_id;
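+    // m_idx walks the output-channel (param.k) dimension of the block tile and n_idx walks the
+    // flattened n*Oh*Ow dimension; lane_id is folded into n_idx so that each lane of the warp
+    // covers a different output pixel in the store loop below.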
 
 #pragma unroll
     for (int i = 0; i < 2; ++i)
@@ -1142,72 +1153,42 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
             {
                 // output sts
                 uint32_t (&reg_)[2] = reinterpret_cast<uint32_t (&)[2]>(acc_register_[mma_m][mma_n]);
-                uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[output_sts_addr +
-                                    mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N]);
+                const uint idx = output_sts_addr +
+                                 mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N;
+                uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[idx]);
                 dst_ptr[0] = reg_[0];
-                dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[output_sts_addr +
-                          mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N + 8 * BN / 2]);
+                dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[idx + 8 * BN / 2]);
                 dst_ptr[0] = reg_[1];
             }
         }
         __syncthreads();
+
 #pragma unroll
-            const uint row = m_idx + j * WSUBN + (lane_id + subk * WARPSIZE) / WSUBM;
-            const uint gemm_i = n_idx + i * WSUBM + (lane_id + subk * WARPSIZE) % WSUBM;
-            const int n = fastdiv(gemm_i, param.OHOW_fastdiv);
-            const int col = fastmodulo(gemm_i, param.OHOW_fastdiv);
-            if (n < param.n && row < param.k && col < param.Oh * param.Ow){
-                // int outOffset = z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow + (m_idx + i * 16 + subk) * param.Oh * param.Ow + (n_idx + j * 32);
-                // if (n < param.n && (m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow)
-                //     param.interm[outOffset] = smemoutput[output_lds_addr + subk * 32];
-                const uint outOffset = ksplit > 0 ?
-                    z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow +
-                    row * param.Oh * param.Ow + col :
-                    z * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col;
-                output[outOffset] = smemoutput[output_lds_addr + subk * WARPSIZE];
+        for (int subk = 0; subk < WN / 2; ++subk){
+            for (int j = 0; j < 4; ++j){
+                const uint row = m_idx + subk + i * WN / 2;
+                const uint gemm_i = n_idx + j*32;
+                const int n = fastdiv(gemm_i, param.OHOW_fastdiv);
+                const int col = fastmodulo(gemm_i, param.OHOW_fastdiv);
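+                // The guarded store below writes an [n, k, Oh*Ow] tensor: element (n, row, col)
+                // lives at ((n * param.k + row) * param.Oh * param.Ow + col), e.g. (n,row,col) =
+                // (1,3,5) lands at (1*param.k + 3)*param.Oh*param.Ow + 5; outOffset expands
+                // exactly that expression.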
+                if(n < param.n && row < param.k && col < param.Oh * param.Ow){
+                    // int outOffset = z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow + (m_idx + i * 16 + subk) * param.Oh * param.Ow + (n_idx + j * 32);
+                    // if (n < param.n && (m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow)
+                    //     param.interm[outOffset] = smemoutput[output_lds_addr + subk * 32];
+                    const uint outOffset = n * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col;
+                    output[outOffset] = smemoutput[output_lds_addr + subk + j*32*BN/2];
+                }
+            }
         }
     }
-
-    //////////////
-    // epilogue //
-    //////////////
-//    half alpha_ = (half)alpha;
-//    half beta_ = (half)beta;
-//    half C_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][4];
-
-//    // calculate pointers for this warps C and D tiles
-//    half* C_block_gmem = C + (block_m * BM_dim * CD_stride) + (block_n * BN_dim);
-//    half* C_warp_gmem = C_block_gmem + (warp_m * WM_dim * CD_stride) + (warp_n * WN_dim);
-//    half* D_block_gmem = D + (block_m * BM_dim * CD_stride) + (block_n * BN_dim);
-//    half* D_warp_gmem = D_block_gmem + (warp_m * WM_dim * CD_stride) + (warp_n * WN_dim);
-
-//    for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++)
-//    {
-//        for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++)
-//        {
-//            half* C_mma_tile = C_warp_gmem + (mma_m * MMA_M_dim * CD_stride) + (mma_n * MMA_N_dim);
-//            ldmatrix_m16n8_gmem(C_mma_tile, C_register[mma_m][mma_n], N * sizeof(half));
-
-//            // scale C by beta
-//            acc_register_[mma_m][mma_n][0] = acc_register_[mma_m][mma_n][0] * alpha_ + C_register[mma_m][mma_n][0] * beta_;
-//            acc_register_[mma_m][mma_n][1] = acc_register_[mma_m][mma_n][1] * alpha_ + C_register[mma_m][mma_n][1] * beta_;
-//            acc_register_[mma_m][mma_n][2] = acc_register_[mma_m][mma_n][2] * alpha_ + C_register[mma_m][mma_n][2] * beta_;
-//            acc_register_[mma_m][mma_n][3] = acc_register_[mma_m][mma_n][3] * alpha_ + C_register[mma_m][mma_n][3] * beta_;
-//        }
-//    }
-
-//    for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++)
-//    {
-//        for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++)
-//        {
-//            half* D_mma_tile = D_warp_gmem + (mma_m * MMA_M_dim * CD_stride) + (mma_n * MMA_N_dim);
-//            stmatrix_m16n8(D_mma_tile, acc_register_[mma_m][mma_n], N * sizeof(half));
-//        }
-//    }
-
+#else
+    GGML_UNUSED(input);
+    GGML_UNUSED(kernel);
+    GGML_UNUSED(output);
+    GGML_UNUSED(param);
+    NO_DEVICE_CODE;
+#endif
 }
-#endif
 
 #define NUM_VARIANTS 6
 
@@ -1266,11 +1247,53 @@ static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D,
     }
 }
 
-static void conv2d_implicit_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const param_t P, cudaStream_t st) {
+static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const float * X_D, const half * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) {
+    // Host-side runtime dispatch: take the tensor-core path only on NVIDIA GPUs with Ampere mma
+    // support, layout 0, and channel counts divisible by 8 (the copy helpers read 8 halves at a time).
+    if (GGML_CUDA_CC_IS_NVIDIA(cc) && ampere_mma_available(cc) && P.layout == 0 && P.c % 8 == 0) {
+        constexpr unsigned int BM_dim = 256;
+        constexpr unsigned int BN_dim = 256;
+        constexpr unsigned int BK_dim = 32;
+
+        constexpr unsigned int WARPS_PER_BLOCK_M = 2;
+        constexpr unsigned int WARPS_PER_BLOCK_N = 4;
+        constexpr unsigned int WARPS_PER_BLOCK_K = 4;
+
+        constexpr unsigned int WM_dim = BM_dim / WARPS_PER_BLOCK_M;
+        constexpr unsigned int WN_dim = BN_dim / WARPS_PER_BLOCK_N;
+        constexpr unsigned int WK_dim = BK_dim / WARPS_PER_BLOCK_K;
+
+        const unsigned int BlocksM = (P.n * P.Oh * P.Ow + BM_dim - 1) / BM_dim;
+        const unsigned int BlocksN = (P.k + BN_dim - 1) / BN_dim;
+        constexpr unsigned int ThreadsM = WARPS_PER_BLOCK_M;
+        constexpr unsigned int ThreadsN = WARPSIZE * WARPS_PER_BLOCK_N;
+        constexpr unsigned int NumThreads = ThreadsM * ThreadsN;
+        const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half);
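+        // With BM_dim = BN_dim = 256 and BK_dim = 32 this is 64 KiB of dynamic shared memory
+        // (presumably two double-buffered half tiles), which exceeds the 48 KiB default limit,
+        // so an opt-in along these lines is likely needed once before launching (sketch only):
+        //   cudaFuncSetAttribute(conv2d_implicit_kernel_tc<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, NumThreads>,
+        //                        cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_bytes);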
+
+        dim3 gridDim(BlocksN, BlocksM);
+        dim3 blockDim(ThreadsN, ThreadsM);
+
+        int id = ggml_cuda_get_device();
+        ggml_cuda_pool_alloc<half> x_f16(ctx.pool(id));
+
+        const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(GGML_TYPE_F32);
+        GGML_ASSERT(to_fp16_cuda != nullptr);
+        size_t ne = P.c * P.h * P.w * P.n;
+        x_f16.alloc(ne);
+        to_fp16_cuda(X_D, x_f16.get(), ne, st);
+        const half *X_H = x_f16.get();
+
+        ggml_cuda_pool_alloc<half> Y_H(ctx.pool(id), P.k * P.Oh * P.Ow * P.n);
+
+        conv2d_implicit_kernel_tc<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, NumThreads>
+            <<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_D, Y_H.get(), P);
+
+        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+        to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow * P.n, st);
+    } else {
     conv2d_implicit_cuda(X_D, K_D, Y_D, P, st);
+    }
 }
 
-static void conv2d_implicit_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const param_t P, cudaStream_t st) {
+static void conv2d_implicit_cuda_f32(ggml_backend_cuda_context & ctx, const float * X_D, const float * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) {
     conv2d_implicit_cuda(X_D, K_D, Y_D, P, st);
 }
 
@@ -1286,6 +1309,7 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor *
 
     cudaStream_t st = ctx.stream();
 
+    const int cc = ggml_cuda_info().devices[ctx.device].cc;
     const int32_t * p = (const int32_t *) dst->op_params;
 
     const int ST_X = p[0];  // stride_x
@@ -1333,8 +1357,8 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor *
     params.layout = LT;
 
     if (kernel->type == GGML_TYPE_F16) {
-        conv2d_implicit_cuda_f16(X_D, (half *) K_D, Y_D, params, st);
+        conv2d_implicit_cuda_f16(ctx, X_D, (half *) K_D, Y_D, cc, params, st);
     } else {
-        conv2d_implicit_cuda_f32(X_D, K_D, Y_D, params, st);
+        conv2d_implicit_cuda_f32(ctx, X_D, K_D, Y_D, cc, params, st);
     }
 }
diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh
index 1a54a184a8..b7d2c8ff2e 100644
--- a/ggml/src/ggml-cuda/conv2d-implicit.cuh
+++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh
@@ -31,9 +31,10 @@ typedef struct{
 
 template <unsigned int TILE_ROWS, unsigned int TILE_COLS, unsigned int NUM_THREADS>
 __device__ __forceinline__ void tileMemcpySwizzleB(
-    half* src,
+    const half* src,
     half* dst,
-    const unsigned int src_stride
+    const unsigned int src_stride,
+    param_t param
 )
 {
     // constexpr unsigned int SWIZZLE_MASK = 0b111 << SWIZZLE_BITS;
@@ -109,7 +110,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
         unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
-        if (thread_row < param.k && curR < param.R && curS < param.S && curC < param.c){
+        if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c){
             dst_float4[dst_index] = reinterpret_cast<const float4*>(&src[src_index])[0];
         }else{ // read 4 halves
             dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
         }
@@ -123,7 +124,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
 
 template <unsigned int TILE_ROWS, unsigned int TILE_COLS, unsigned int NUM_THREADS>
 __device__ __forceinline__ void tileMemcpySwizzleA(
-    half* src,
+    const half* src,
     half* dst,
     // const unsigned int src_stride,
     const unsigned int inChannelOffset,
@@ -175,7 +176,7 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
         if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
-            curR < param.R && curS < param.S && curC < param.c){
+            curR < param.r && curS < param.s && curC < param.c){
             // const unsigned int src_index = thread_row * src_stride_vectorized + thread_col;
             const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
             dst_float4[dst_index] = reinterpret_cast<const float4*>(&src[inOffset + inOffsetTmp])[0];
@@ -191,7 +192,7 @@ unsigned int TILE_COLS,
 unsigned int NUM_THREADS,
 unsigned int ELEMENTS_PER_THREAD>
 __device__ __forceinline__ void tileMemcpyLoadA(
-    half* src,
+    const half* src,
     float4 (&dst_reg)[ELEMENTS_PER_THREAD],
     // const unsigned int src_stride,
     const unsigned int block_k,
@@ -200,7 +201,7 @@ __device__ __forceinline__ void tileMemcpyLoadA(
 )
 {
     // reinterpret input/output as float4
-    float4* src_float4 = reinterpret_cast<float4*>(src);
+    // const float4* src_float4 = reinterpret_cast<const float4*>(src);
     // const unsigned int src_stride_vectorized = src_stride / 8;
 
     // # of threads is multiple of # of columns in the tile
@@ -239,7 +240,7 @@ __device__ __forceinline__ void tileMemcpyLoadA(
         int curH = posh_ori + curR * param.d_h;   // input h
         int curW = posw_ori + curS * param.d_w;   // input w
         if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
-            curR < param.R && curS < param.S && curC < param.c){
+            curR < param.r && curS < param.s && curC < param.c){
             // const unsigned int src_index = thread_row * src_stride_vectorized + thread_col;
             const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
             dst_reg[i] = reinterpret_cast<const float4*>(&src[inOffset + inOffsetTmp])[0];
@@ -256,7 +257,7 @@ unsigned int TILE_COLS,
 unsigned int NUM_THREADS,
 unsigned int ELEMENTS_PER_THREAD>
 __device__ __forceinline__ void tileMemcpyLoadB(
-    half* src,
+    const half* src,
     float4 (&dst_reg)[ELEMENTS_PER_THREAD],
     const unsigned int block_k,
     const unsigned int src_stride,
@@ -264,7 +265,7 @@ __device__ __forceinline__ void tileMemcpyLoadB(
 )
 {
     // reinterpret input/output as float4
-    float4* src_float4 = reinterpret_cast<float4*>(src);
+    // const float4* src_float4 = reinterpret_cast<const float4*>(src);
     // const unsigned int src_stride_vectorized = src_stride / 8;
 
     // # of threads is multiple of # of columns in the tile
@@ -295,7 +296,7 @@ __device__ __forceinline__ void tileMemcpyLoadB(
         // dst_reg[i] = src_float4[src_index];
         // thread_row += ROW_STEP;
         const unsigned int src_index = thread_row * src_stride + block_k + thread_col * 8;
-        if (thread_row < param.k && curR < param.R && curS < param.S && curC < param.c){
+        if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c){
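+            // in-bounds case: rows of the B tile are output channels (param.k) and the column
+            // decomposes into the (r, s, c) filter footprint; out-of-range elements take the
+            // zero-fill branch below, so the GEMM effectively sees zero padding.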
             dst_reg[i] = reinterpret_cast<const float4*>(&src[src_index])[0];
         }else{ // read 4 halves
             dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);
@@ -449,5 +450,15 @@ __device__ __forceinline__ uint32_t cvta_to_shared_u32(const void *pointer) {
 #endif
 
+// constexpr unsigned int int_log2(unsigned int x)
+// {
+//     unsigned int result = 0;
+//     while (x >>= 1)
+//     {
+//         result++;
+//     }
+//     return result;
+// }
+
 #define CUDA_CONV2D_IMPLICT_BLOCK_SIZE 256
 
 void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst);