From 8a589317b6bfe60d732a1a0d1b9bb153145f9fd2 Mon Sep 17 00:00:00 2001
From: bssrdf
Date: Tue, 2 Sep 2025 22:47:41 -0400
Subject: [PATCH] Add implicit GEMM convolution operation for 2D tensors in CUDA

---
 ggml/include/ggml.h                    |  12 +
 ggml/src/ggml-cuda/conv2d-implicit.cu  | 394 +++++++++++++++++++++++++
 ggml/src/ggml-cuda/conv2d-implicit.cuh |   5 +
 ggml/src/ggml.c                        |  39 +++
 4 files changed, 450 insertions(+)
 create mode 100644 ggml/src/ggml-cuda/conv2d-implicit.cu
 create mode 100644 ggml/src/ggml-cuda/conv2d-implicit.cuh

diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 7e9c3c8c7a..d37a0a91ff 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -512,6 +512,7 @@ extern "C" {
         GGML_OP_IM2COL,
         GGML_OP_IM2COL_BACK,
         GGML_OP_CONV_2D,
+        GGML_OP_CONV_2D_IMPLICIT,
         GGML_OP_CONV_3D,
         GGML_OP_CONV_2D_DW,
         GGML_OP_CONV_TRANSPOSE_2D,
@@ -1941,6 +1942,17 @@ extern "C" {
                     int                   d0,  // dilation dimension 0
                     int                   d1); // dilation dimension 1
 
+    GGML_API struct ggml_tensor * ggml_conv_2d_implicitgemm(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,   // convolution kernel [KW, KH, IC, OC]
+            struct ggml_tensor  * b,   // input data [W, H, C, N]
+            int                   s0,  // stride dimension 0
+            int                   s1,  // stride dimension 1
+            int                   p0,  // padding dimension 0
+            int                   p1,  // padding dimension 1
+            int                   d0,  // dilation dimension 0
+            int                   d1); // dilation dimension 1
+
     GGML_API struct ggml_tensor * ggml_conv_3d(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,   // kernel [KW, KH, KD, IC * OC]
diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu
new file mode 100644
index 0000000000..d1b1dc7d3c
--- /dev/null
+++ b/ggml/src/ggml-cuda/conv2d-implicit.cu
@@ -0,0 +1,394 @@
+#include "conv2d-implicit.cuh"
+#include "convert.cuh"
+
+struct conv_params {
+    const int64_t IW, IH;
+    const int64_t OW, OH;
+    const int64_t KW, KH;
+    const int64_t ST_X, ST_Y;
+    const int64_t PD_X, PD_Y;
+    const int64_t DL_X, DL_Y;
+    const int64_t IC, OC;
+    const int64_t B;
+    const int64_t TOTAL;
+};
+
+struct kernel_bounds {
+    int64_t y_min, y_max;
+    int64_t x_min, x_max;
+};
+
+__device__ __forceinline__ int64_t max64(int64_t a, int64_t b) {
+    return (a > b) ? a : b;
+}
+
+__device__ __forceinline__ int64_t min64(int64_t a, int64_t b) {
+    return (a < b) ? a : b;
+}
+
+__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int64_t out_x, int64_t out_y, const conv_params & P) {
+    kernel_bounds bounds;
+    bounds.y_min = max64(0, (P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y);
+    bounds.y_max = min64(P.KH, (P.IH + P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y);
+    bounds.x_min = max64(0, (P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X);
+    bounds.x_max = min64(P.KW, (P.IW + P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X);
+    return bounds;
+}
+
+__device__ __forceinline__ int calculate_input_coord(int64_t out_coord,
+                                                     int64_t kern_coord,
+                                                     int64_t stride,
+                                                     int64_t dilation,
+                                                     int64_t padding) {
+    return out_coord * stride + kern_coord * dilation - padding;
+}
+
+struct whcn_layout {
+    __device__ static int64_t input_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) {
+        return n * (P.IC * P.IW * P.IH) + c * P.IW * P.IH + y * P.IW + x;
+    }
+
+    __device__ static int64_t kernel_index(int64_t c_out, int64_t c_in, int64_t ky, int64_t kx, const conv_params & P) {
+        return c_out * (P.IC * P.KH * P.KW) + c_in * (P.KH * P.KW) + ky * P.KW + kx;
+    }
+
+    __device__ static int64_t output_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) {
+        return n * (P.OC * P.OW * P.OH) + c * P.OW * P.OH + y * P.OW + x;
+    }
+
+    __device__ static void unpack_indices(int64_t global_idx,
+                                          const conv_params & P,
+                                          int64_t & n,
+                                          int64_t & c,
+                                          int64_t & out_y,
+                                          int64_t & out_x) {
+        out_x = global_idx % P.OW;
+        out_y = (global_idx / P.OW) % P.OH;
+        c     = (global_idx / (P.OW * P.OH)) % P.OC;
+        n     = global_idx / (P.OW * P.OH * P.OC);
+    }
+};
+
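+// Implicit-GEMM tiling used by the kernel below:
+//   - the convolution is treated as a GEMM of size [OC] x [OH*OW] with reduction
+//     dimension IC*KH*KW; no im2col buffer is materialized, indices are computed
+//     on the fly
+//   - each block of 256 threads computes one 128x128 output tile
+//     (blockIdx.y tiles the output channels, blockIdx.x tiles the output pixels,
+//     blockIdx.z is the batch index)
+//   - the reduction dimension is consumed in slices of 8, double-buffered in
+//     shared memory; each thread accumulates an 8x8 fragment in registers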
+template <typename T>
+static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
+                                              const T * __restrict__ kernel,
+                                              float * __restrict__ output,
+                                              const conv_params P) {
+
+    __shared__ __align__(16 * 1024) char smem[24 * 1024];
+    T     * smemweight = reinterpret_cast<T *>(smem);
+    float * smeminput  = reinterpret_cast<float *>(smem + 16 * 1024);
+
+    int tx = threadIdx.x;
+    int bx = blockIdx.x;
+    int by = blockIdx.y;
+
+    // Warp tile
+    const int lane_id = threadIdx.x % 32;
+    const int warp_id = threadIdx.x / 32;
+    const int mma_tid_x = (lane_id / 2) % 8;
+    const int mma_tid_y = (lane_id / 16) * 2 + (lane_id % 2);
+    // lds addr
+    int weight_lds_addr = (warp_id / 2) * 32 + mma_tid_y * 4;
+    int input_lds_addr  = (warp_id % 2) * 64 + mma_tid_x * 4;
+
+    int x = bx * 128 + input_lds_addr;
+    int y = by * 128 + weight_lds_addr;
+    int z = blockIdx.z;
+
+    T     weight_ldg_reg[4];
+    float input_ldg_reg[4];
+
+    int posh_ori[4];
+    int posw_ori[4];
+#pragma unroll
+    for (int i = 0; i < 4; ++i)
+    {
+        posh_ori[i] = ((bx * 128 + tx % 32 + i * 32) / P.OW) * P.ST_Y - P.PD_Y;
+        posw_ori[i] = ((bx * 128 + tx % 32 + i * 32) % P.OW) * P.ST_X - P.PD_X;
+    }
+
+    int inOffset            = z * P.IC * P.IH * P.IW;
+    int weiOffset           = (by * 128 + tx / 8 * 4) * P.IC * P.KH * P.KW;
+    int inChannelOffset     = P.IH * P.IW;
+    int weightChannelOffset = P.KH * P.KW;
+    int weightKOffset       = P.IC * P.KH * P.KW;
+
+    // sts addr
+    int weight_sts_addr = (tx % 8) * 132 + (tx / 8) * 4;
+    int input_sts_addr  = (tx / 32) * 128 + (tx % 32);
+
+    int write_flag = 1;
+    T     weight_frag[2][8];
+    float input_frag[2][8];
+    float output_frag[8][8];
+#pragma unroll
+    for (int i = 0; i < 8; ++i)
+    {
+#pragma unroll
+        for (int j = 0; j < 8; ++j)
+        {
+            output_frag[i][j] = 0;
+        }
+    }
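+    // prologue: load the first 8-wide slice of the IC*KH*KW reduction dimension
+    // (4 filter values and 4 input values per thread) from global memory into
+    // registers, then stage it in shared memory before entering the main loop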
+// ldg
+#pragma unroll
+    for (int i = 0; i < 4; ++i)
+    {
+        if (tx % 8 < weightKOffset && by * 128 + tx / 8 * 4 + i < P.OC)
+        {
+            weight_ldg_reg[i] = kernel[weiOffset + tx % 8 + i * weightKOffset];
+        }
+        else
+        {
+            weight_ldg_reg[i] = (T)0.f;
+        }
+    }
+    int curC = (tx / 32) / (P.KH * P.KW);          // channel offset
+    int curR = ((tx / 32) % (P.KH * P.KW)) / P.KW; // kernel r offset
+    int curS = ((tx / 32) % (P.KH * P.KW)) % P.KW; // kernel s offset
+#pragma unroll
+    for (int i = 0; i < 4; ++i)
+    {
+        int curH = posh_ori[i] + curR; // input h
+        int curW = posw_ori[i] + curS; // input w
+        int inOffsetTmp = curC * inChannelOffset + curH * P.IW + curW;
+        if (curH >= 0 && curW >= 0 && curW < P.IW && curH < P.IH)
+        {
+            input_ldg_reg[i] = input[inOffset + inOffsetTmp];
+        }
+        else
+        {
+            input_ldg_reg[i] = 0.0;
+        }
+    }
+    // sts
+    for (int i = 0; i < 4; ++i)
+    {
+        smemweight[weight_sts_addr + i] = weight_ldg_reg[i];
+    }
+    for (int i = 0; i < 4; ++i)
+    {
+        smeminput[input_sts_addr + i * 32] = input_ldg_reg[i];
+    }
+
+    __syncthreads();
+    // lds
+#pragma unroll
+    for (int i = 0; i < 4; ++i)
+    {
+        weight_frag[0][i]     = smemweight[weight_lds_addr + i];
+        weight_frag[0][i + 4] = smemweight[weight_lds_addr + i + 16];
+    }
+#pragma unroll
+    for (int i = 0; i < 4; ++i)
+    {
+        input_frag[0][i]     = smeminput[input_lds_addr + i];
+        input_frag[0][i + 4] = smeminput[input_lds_addr + i + 32];
+    }
+    for (int crs = 0; crs < P.KH * P.KW * P.IC; crs += 8)
+    {
+        // ldg
+        int weiOffsetTmp = crs + 8 + tx % 8;
+#pragma unroll
+        for (int i = 0; i < 4; ++i)
+        {
+            if (weiOffsetTmp < weightKOffset && by * 128 + tx / 8 * 4 + i < P.OC)
+            {
+                weight_ldg_reg[i] = kernel[weiOffset + weiOffsetTmp + i * weightKOffset];
+            }
+            else
+            {
+                weight_ldg_reg[i] = (T)0.f;
+            }
+        }
+        curC = (crs + 8 + tx / 32) / (P.KH * P.KW);          // channel offset
+        curR = ((crs + 8 + tx / 32) % (P.KH * P.KW)) / P.KW; // kernel r offset
+        curS = ((crs + 8 + tx / 32) % (P.KH * P.KW)) % P.KW; // kernel s offset
+
+#pragma unroll
+        for (int i = 0; i < 4; ++i)
+        {
+            int curH = posh_ori[i] + curR; // input h
+            int curW = posw_ori[i] + curS; // input w
+            int inOffsetTmp = curC * inChannelOffset + curH * P.IW + curW;
+            if (curH >= 0 && curW >= 0 && curW < P.IW && curH < P.IH)
+            {
+                input_ldg_reg[i] = input[inOffset + inOffsetTmp];
+            }
+            else
+            {
+                input_ldg_reg[i] = 0.f;
+            }
+        }
+        int load_flag = write_flag ^ 1;
+#pragma unroll
+        for (int subcrs = 0; subcrs < 8 - 1; ++subcrs)
+        {
+#pragma unroll
+            for (int i = 0; i < 4; ++i)
+            {
+                weight_frag[(subcrs + 1) % 2][i]     = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i];
+                weight_frag[(subcrs + 1) % 2][i + 4] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i + 16];
+            }
+#pragma unroll
+            for (int i = 0; i < 4; ++i)
+            {
+                input_frag[(subcrs + 1) % 2][i]     = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i];
+                input_frag[(subcrs + 1) % 2][i + 4] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i + 32];
+            }
+
+#pragma unroll
+            for (int i = 0; i < 8; ++i)
+            {
+#pragma unroll
+                for (int j = 0; j < 8; ++j)
+                {
+                    output_frag[i][j] += ggml_cuda_cast<float>(weight_frag[subcrs % 2][i]) * input_frag[subcrs % 2][j];
+                }
+            }
+        }
+        // sts
+        for (int i = 0; i < 4; ++i)
+        {
+            smemweight[write_flag * 132 * 8 + weight_sts_addr + i] = weight_ldg_reg[i];
+        }
+        for (int i = 0; i < 4; ++i)
+        {
+            smeminput[write_flag * 128 * 8 + input_sts_addr + i * 32] = input_ldg_reg[i];
+        }
+        __syncthreads();
+        write_flag ^= 1;
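+        // buffers have been swapped: prefetch fragment 0 of the next slice from
+        // the freshly written buffer, then finish the last sub-step of the
+        // current slice with fragment 1 below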
+#pragma unroll
+        for (int i = 0; i < 4; ++i)
+        {
+            weight_frag[0][i]     = smemweight[(load_flag ^ 1) * 132 * 8 + weight_lds_addr + i];
+            weight_frag[0][i + 4] = smemweight[(load_flag ^ 1) * 132 * 8 + weight_lds_addr + i + 16];
+        }
+#pragma unroll
+        for (int i = 0; i < 4; ++i)
+        {
+            input_frag[0][i]     = smeminput[(load_flag ^ 1) * 128 * 8 + input_lds_addr + i];
+            input_frag[0][i + 4] = smeminput[(load_flag ^ 1) * 128 * 8 + input_lds_addr + i + 32];
+        }
+#pragma unroll
+        for (int i = 0; i < 8; ++i)
+        {
+#pragma unroll
+            for (int j = 0; j < 8; ++j)
+            {
+                output_frag[i][j] += ggml_cuda_cast<float>(weight_frag[1][i]) * input_frag[1][j];
+            }
+        }
+    }
+
+    // reuse smem
+    float * smemoutput = reinterpret_cast<float *>(smem);
+    // float * smembias = reinterpret_cast<float *>(smem + 16 * 1024);
+
+    // bias ldg/sts
+    // if (tx < 128)
+    // {
+    //     smembias[tx] = param.bias[by * 128 + tx];
+    // }
+
+    uint32_t output_sts_addr = warp_id * 512 + mma_tid_y * 4 * 8 * 4 + mma_tid_x * 4;
+    uint32_t output_lds_addr = warp_id * 512 + lane_id;
+    // uint32_t bias_lds_addr = warp_id / 2 * 32;
+
+    uint32_t m_idx = blockIdx.y * 128 + warp_id / 2 * 32;
+    uint32_t n_idx = blockIdx.x * 128 + warp_id % 2 * 64 + lane_id;
+
+#pragma unroll
+    for (int i = 0; i < 2; ++i)
+    {
+#pragma unroll
+        for (int j = 0; j < 2; ++j)
+        {
+            __syncthreads();
+
+#pragma unroll
+            for (int subi = 0; subi < 4; ++subi)
+            {
+#pragma unroll
+                for (int subj = 0; subj < 4; ++subj)
+                {
+                    // output sts
+                    smemoutput[output_sts_addr + subi * 8 * 4 + subj] = output_frag[i * 4 + subi][j * 4 + subj];
+                }
+            }
+            __syncthreads();
+
+#pragma unroll
+            for (int subk = 0; subk < 16; ++subk)
+            {
+                int outOffset = z * P.OC * P.OH * P.OW + (m_idx + i * 16 + subk) * P.OH * P.OW + n_idx + j * 32;
+                if ((m_idx + i * 16 + subk) < P.OC && (n_idx + j * 32) < P.OH * P.OW)
+                    // output[outOffset] = smemoutput[output_lds_addr + subk * 32] + smembias[bias_lds_addr + i * 16 + subk];
+                    output[outOffset] = smemoutput[output_lds_addr + subk * 32];
+            }
+        }
+    }
+}
+
+template <typename T>
+static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
+    // one 128x128 output tile per block: x tiles the OH*OW output pixels,
+    // y tiles the OC output channels, z is the batch index
+    dim3 grid((P.OW * P.OH + 127) / 128, (P.OC + 127) / 128, P.B);
+    conv2d_implicit_kernel<T><<<grid, CUDA_CONV2D_IMPLICIT_BLOCK_SIZE, 0, st>>>(X_D, K_D, Y_D, P);
+}
+
+static void conv2d_implicit_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
+    conv2d_implicit_cuda<half>(X_D, K_D, Y_D, P, st);
+}
+
+static void conv2d_implicit_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
+    conv2d_implicit_cuda<float>(X_D, K_D, Y_D, P, st);
+}
+
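+// host entry point: validates the tensors, reads stride/padding/dilation from
+// dst->op_params (as written by ggml_conv_2d_implicitgemm) and dispatches on
+// the kernel tensor's data type (F16 or F32)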
+void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * kernel = dst->src[0];
+    const ggml_tensor * input  = dst->src[1];
+    float * K_D = (float *) kernel->data;
+    const float * X_D = (const float *) input->data;
+    float * Y_D = (float *) dst->data;
+
+    GGML_ASSERT(ggml_is_contiguous(kernel));
+    GGML_ASSERT(kernel->type == GGML_TYPE_F16 || kernel->type == GGML_TYPE_F32);
+
+    // same number of input channels
+    GGML_ASSERT(input->ne[2] == kernel->ne[2]);
+
+    cudaStream_t st = ctx.stream();
+
+    const int32_t * p = (const int32_t *) dst->op_params;
+    const int ST_X = p[0]; // stride_x
+    const int ST_Y = p[1]; // stride_y
+    const int PD_X = p[2]; // padding_x
+    const int PD_Y = p[3]; // padding_y
+    const int DL_X = p[4]; // dilation_x
+    const int DL_Y = p[5]; // dilation_y
+
+    // No cwhn
+    GGML_ASSERT(p[6] == false);
+
+    const int IW = input->ne[0];  // input_w
+    const int IH = input->ne[1];  // input_h
+    const int OW = dst->ne[0];    // output_w
+    const int OH = dst->ne[1];    // output_h
+    const int KW = kernel->ne[0]; // kernel_w
+    const int KH = kernel->ne[1]; // kernel_h
+    const int IC = input->ne[2];  // input_channels
+    const int OC = kernel->ne[3]; // output_channels
+    const int B  = input->ne[3];  // n_batches
+
+    const int64_t total = B * OC * OH * OW;
+    conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total };
+
+    if (kernel->type == GGML_TYPE_F16) {
+        conv2d_implicit_cuda_f16(X_D, (half *) K_D, Y_D, params, st);
+    } else {
+        conv2d_implicit_cuda_f32(X_D, K_D, Y_D, params, st);
+    }
+}
diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh
new file mode 100644
index 0000000000..46161feb3c
--- /dev/null
+++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh
@@ -0,0 +1,5 @@
+#pragma once
+#include "common.cuh"
+
+#define CUDA_CONV2D_IMPLICIT_BLOCK_SIZE 256
+void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index d76ea58f78..4e0fd672bd 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -4482,6 +4482,45 @@ struct ggml_tensor * ggml_conv_2d_direct(
     return result;
 }
 
+
+// ggml_conv_2d_implicitgemm
+
+struct ggml_tensor * ggml_conv_2d_implicitgemm(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,   // convolution kernel [KW, KH, IC, OC]
+        struct ggml_tensor  * b,   // input data [W, H, C, N]
+        int                   s0,  // stride dimension 0
+        int                   s1,  // stride dimension 1
+        int                   p0,  // padding dimension 0
+        int                   p1,  // padding dimension 1
+        int                   d0,  // dilation dimension 0
+        int                   d1) {// dilation dimension 1
+
+    GGML_ASSERT(a->ne[2] == b->ne[2]);
+    //GGML_ASSERT(a->type == b->type);
+
+    int64_t ne[4];
+    ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
+    ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);
+    ne[2] = a->ne[3];
+    ne[3] = b->ne[3];
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne);
+
+    ggml_set_op_params_i32(result, 0, s0);
+    ggml_set_op_params_i32(result, 1, s1);
+    ggml_set_op_params_i32(result, 2, p0);
+    ggml_set_op_params_i32(result, 3, p1);
+    ggml_set_op_params_i32(result, 4, d0);
+    ggml_set_op_params_i32(result, 5, d1);
+
+    result->op     = GGML_OP_CONV_2D_IMPLICIT;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
 // ggml_conv_3d
 
 struct ggml_tensor * ggml_conv_3d(
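
Note (not part of the patch): a minimal sketch of how the new op might be exercised, assuming a valid ggml_context `ctx0` and that the graph is executed on the CUDA backend; tensor shapes and variable names are illustrative only.

    // 3x3 kernel, 64 -> 128 channels, F16 weights (the CUDA path accepts F16 or F32)
    struct ggml_tensor * kernel = ggml_new_tensor_4d(ctx0, GGML_TYPE_F16, 3, 3, 64, 128);
    // 256x256 F32 input with 64 channels, batch of 1
    struct ggml_tensor * input  = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, 256, 256, 64, 1);
    // stride 1x1, padding 1x1, dilation 1x1 -> [256, 256, 128, 1] output
    struct ggml_tensor * out    = ggml_conv_2d_implicitgemm(ctx0, kernel, input, 1, 1, 1, 1, 1, 1);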