From 8a589317b6bfe60d732a1a0d1b9bb153145f9fd2 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Tue, 2 Sep 2025 22:47:41 -0400 Subject: [PATCH 001/122] Add implicit GEMM convolution operation for 2D tensors in CUDA --- ggml/include/ggml.h | 12 + ggml/src/ggml-cuda/conv2d-implicit.cu | 394 +++++++++++++++++++++++++ ggml/src/ggml-cuda/conv2d-implicit.cuh | 5 + ggml/src/ggml.c | 39 +++ 4 files changed, 450 insertions(+) create mode 100644 ggml/src/ggml-cuda/conv2d-implicit.cu create mode 100644 ggml/src/ggml-cuda/conv2d-implicit.cuh diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 7e9c3c8c7a..d37a0a91ff 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -512,6 +512,7 @@ extern "C" { GGML_OP_IM2COL, GGML_OP_IM2COL_BACK, GGML_OP_CONV_2D, + GGML_OP_CONV_2D_IMPLICIT, GGML_OP_CONV_3D, GGML_OP_CONV_2D_DW, GGML_OP_CONV_TRANSPOSE_2D, @@ -1941,6 +1942,17 @@ extern "C" { int d0, // dilation dimension 0 int d1); // dilation dimension 1 + GGML_API struct ggml_tensor * ggml_conv_2d_implicitgemm( + struct ggml_context * ctx, + struct ggml_tensor * a, // convolution kernel [KW, KH, IC, OC] + struct ggml_tensor * b, // input data [W, H, C, N] + int s0, // stride dimension 0 + int s1, // stride dimension 1 + int p0, // padding dimension 0 + int p1, // padding dimension 1 + int d0, // dilation dimension 0 + int d1); // dilation dimension 1 + GGML_API struct ggml_tensor * ggml_conv_3d( struct ggml_context * ctx, struct ggml_tensor * a, // kernel [KW, KH, KD, IC * OC] diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu new file mode 100644 index 0000000000..d1b1dc7d3c --- /dev/null +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -0,0 +1,394 @@ +#include "conv2d-implicit.cuh" +#include "convert.cuh" + +struct conv_params { + const int64_t IW, IH; + const int64_t OW, OH; + const int64_t KW, KH; + const int64_t ST_X, ST_Y; + const int64_t PD_X, PD_Y; + const int64_t DL_X, DL_Y; + const int64_t IC, OC; + const int64_t B; + const int64_t TOTAL; +}; + +struct kernel_bounds { + int64_t y_min, y_max; + int64_t x_min, x_max; +}; + +__device__ __forceinline__ int64_t max64(int64_t a, int64_t b) { + return (a > b) ? a : b; +} + +__device__ __forceinline__ int64_t min64(int64_t a, int64_t b) { + return (a < b) ? 
a : b; +} + +__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int64_t out_x, int64_t out_y, const conv_params & P) { + kernel_bounds bounds; + bounds.y_min = max64(0, (P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y); + bounds.y_max = min64(P.KH, (P.IH + P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y); + bounds.x_min = max64(0, (P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X); + bounds.x_max = min64(P.KW, (P.IW + P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X); + return bounds; +} + +__device__ __forceinline__ int calculate_input_coord(int64_t out_coord, + int64_t kern_coord, + int64_t stride, + int64_t dilation, + int64_t padding) { + return out_coord * stride + kern_coord * dilation - padding; +} + +struct whcn_layout { + __device__ static int64_t input_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) { + return n * (P.IC * P.IW * P.IH) + c * P.IW * P.IH + y * P.IW + x; + } + + __device__ static int64_t kernel_index(int64_t c_out, int64_t c_in, int64_t ky, int64_t kx, const conv_params & P) { + return c_out * (P.IC * P.KH * P.KW) + c_in * (P.KH * P.KW) + ky * P.KW + kx; + } + + __device__ static int64_t output_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) { + return n * (P.OC * P.OW * P.OH) + c * P.OW * P.OH + y * P.OW + x; + } + + __device__ static void unpack_indices(int64_t global_idx, + const conv_params & P, + int64_t & n, + int64_t & c, + int64_t & out_y, + int64_t & out_x) { + out_x = global_idx % P.OW; + out_y = (global_idx / P.OW) % P.OH; + c = (global_idx / (P.OW * P.OH)) % P.OC; + n = global_idx / (P.OW * P.OH * P.OC); + } +}; + +template +static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, + const T * __restrict__ kernel, + float * __restrict__ output, + const conv_params P) { + + __shared__ __align__(16 * 1024) char smem[24 * 1024]; + T *smemweight = reinterpret_cast(smem); + float *smeminput = reinterpret_cast(smem + 16 * 1024); + + int tx = threadIdx.x; + int bx = blockIdx.x; + int by = blockIdx.y; + + // Warp tile + const int lane_id = threadIdx.x % 32; + const int warp_id = threadIdx.x / 32; + const int mma_tid_x = (lane_id / 2) % 8; + const int mma_tid_y = (lane_id / 16) * 2 + (lane_id % 2); + // lds addr + int weight_lds_addr = (warp_id / 2) * 32 + mma_tid_y * 4; + int input_lds_addr = (warp_id % 2) * 64 + mma_tid_x * 4; + + int x = bx * 128 + input_lds_addr; + int y = by * 128 + weight_lds_addr; + int z = blockIdx.z; + + T weight_ldg_reg[4]; + float input_ldg_reg[4]; + + int posh_ori[4]; + int posw_ori[4]; +#pragma unroll + for (int i = 0; i < 4; ++i) + { + posh_ori[i] = ((bx * 128 + tx % 32 + i * 32) / param.Ow) * param.u - param.p; + posw_ori[i] = ((bx * 128 + tx % 32 + i * 32) % param.Ow) * param.v - param.q; + } + + int inOffset = z * param.c * param.h * param.w; + int weiOffset = (by * 128 + tx / 8 * 4) * param.c * param.r * param.s; + int inChannelOffset = param.h * param.w; + int weightChannelOffset = param.r * param.s; + int weightKOffset = param.c * param.r * param.s; + + // sts addr + int weight_sts_addr = (tx % 8) * 132 + + (tx / 8) * 4; + int input_sts_addr = (tx / 32) * 128 + (tx % 32); + + int write_flag = 1; + T weight_frag[2][8]; + float input_frag[2][8]; + float output_frag[8][8]; +#pragma unroll + for (int i = 0; i < 8; ++i) + { +#pragma unroll + for (int j = 0; j < 8; ++j) + { + output_frag[i][j] = 0; + } + } +// ldg +#pragma unroll + for (int i = 0; i < 4; ++i) + { + if (tx % 8 < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k) + { + weight_ldg_reg[i] 
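        // Sketch of the implicit-GEMM view behind the index arithmetic above
        // (an annotation inferred from the code, not stated by the patch itself):
        //   GEMM M = param.k                (output channels)
        //   GEMM N = param.Oh * param.Ow    (output pixels; one batch per blockIdx.z)
        //   GEMM K = param.c * param.r * param.s   (reduction over C, R, S)
        // Each thread block accumulates a 128x128 tile of this MxN product, each of
        // the 8 warps a 32x64 sub-tile, and each thread an 8x8 register fragment
        // assembled from strided 4-wide pieces, roughly:
        //   rows  weight_lds_addr+0..3  and  weight_lds_addr+16..19
        //   cols  input_lds_addr+0..3   and  input_lds_addr+32..35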
= kernel[weiOffset + tx % 8 + i * weightKOffset]; + } + else + { + weight_ldg_reg[i] = (T)0.f; + } + } + int curC = (tx / 32) / (param.r * param.s); // channel offset + int curR = ((tx / 32) % (param.r * param.s)) / param.s; // kernel r offset + int curS = ((tx / 32) % (param.r * param.s)) % param.s; // kernel s offset +#pragma unroll + for (int i = 0; i < 4; ++i) + { + int curH = posh_ori[i] + curR; // input h + int curW = posw_ori[i] + curS; // input w + int inOffsetTmp = curC * inChannelOffset + curH * param.w + curW; + if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h) + { + input_ldg_reg[i] = input[inOffset + inOffsetTmp]; + } + else + { + input_ldg_reg[i] = 0.0; + } + } + // sts + for (int i = 0; i < 4; ++i) + { + smemweight[weight_sts_addr + i] = weight_ldg_reg[i]; + } + for (int i = 0; i < 4; ++i) + { + smeminput[input_sts_addr + i * 32] = input_ldg_reg[i]; + } + + __syncthreads(); + // lds +#pragma unroll + for (int i = 0; i < 4; ++i) + { + weight_frag[0][i] = smemweight[weight_lds_addr + i]; + weight_frag[0][i + 4] = smemweight[weight_lds_addr + i + 16]; + } +#pragma unroll + for (int i = 0; i < 4; ++i) + { + input_frag[0][i] = smeminput[input_lds_addr + i]; + input_frag[0][i + 4] = smeminput[input_lds_addr + i + 32]; + } + for (int crs = 0; crs < param.r * param.s * param.c; crs += 8) + { + // ldg + int weiOffsetTmp = crs + 8 + tx % 8; +#pragma unroll + for (int i = 0; i < 4; ++i) + { + if (weiOffsetTmp < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k) + { + weight_ldg_reg[i] = kernel[weiOffset + weiOffsetTmp + i * weightKOffset]; + } + else + { + weight_ldg_reg[i] = (T)0.f; + } + } + curC = (crs + 8 + tx / 32) / (param.r * param.s); // channel offset + curR = ((crs + 8 + tx / 32) % (param.r * param.s)) / param.s; // kernel r offset + curS = ((crs + 8 + tx / 32) % (param.r * param.s)) % param.s; // kernel s offset + +#pragma unroll + for (int i = 0; i < 4; ++i) + { + int curH = posh_ori[i] + curR; // input h + int curW = posw_ori[i] + curS; // input w + int inOffsetTmp = curC * inChannelOffset + curH * param.w + curW; + if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h) + { + input_ldg_reg[i] = input[inOffset + inOffsetTmp]; + } + else + { + input_ldg_reg[i] = 0.f; + } + } + int load_flag = write_flag ^ 1; +#pragma unroll + for (int subcrs = 0; subcrs < 8 - 1; ++subcrs) + { +#pragma unroll + for (int i = 0; i < 4; ++i) + { + weight_frag[(subcrs + 1) % 2][i] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i]; + weight_frag[(subcrs + 1) % 2][i + 4] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i + 16]; + } +#pragma unroll + for (int i = 0; i < 4; ++i) + { + input_frag[(subcrs + 1) % 2][i] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i]; + input_frag[(subcrs + 1) % 2][i + 4] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i + 32]; + } + +#pragma unroll + for (int i = 0; i < 8; ++i) + { +#pragma unroll + for (int j = 0; j < 8; ++j) + { + output_frag[i][j] += ggml_cuda_cast(weight_frag[subcrs % 2][i]) * input_frag[subcrs % 2][j]; + } + } + } + // sts + for (int i = 0; i < 4; ++i) + { + smemweight[write_flag * 132 * 8 + weight_sts_addr + i] = weight_ldg_reg[i]; + } + for (int i = 0; i < 4; ++i) + { + smeminput[write_flag * 128 * 8 + input_sts_addr + i * 32] = input_ldg_reg[i]; + } + __syncthreads(); + write_flag ^= 1; +#pragma unroll + for (int i = 0; i < 4; ++i) + { + weight_frag[0][i] = smemweight[(load_flag ^ 1) * 132 * 8 + weight_lds_addr + i]; + 
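        // Double-buffering note (inferred from the code): smemweight and smeminput
        // each hold two 8-deep tiles selected by the write_flag / load_flag pair.
        // While the FMA loop above consumes the tile at load_flag, the next 8 CRS
        // elements are fetched from global memory into registers and stored to the
        // write_flag half; the __syncthreads() plus `write_flag ^= 1` above swap the
        // roles, and the fragment reloads here prime registers for the next
        // main-loop iteration from the freshly written buffer.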
weight_frag[0][i + 4] = smemweight[(load_flag ^ 1) * 132 * 8 + weight_lds_addr + i + 16]; + } +#pragma unroll + for (int i = 0; i < 4; ++i) + { + input_frag[0][i] = smeminput[(load_flag ^ 1) * 128 * 8 + input_lds_addr + i]; + input_frag[0][i + 4] = smeminput[(load_flag ^ 1) * 128 * 8 + input_lds_addr + i + 32]; + } +#pragma unroll + for (int i = 0; i < 8; ++i) + { +#pragma unroll + for (int j = 0; j < 8; ++j) + { + output_frag[i][j] += ggml_cuda_cast(weight_frag[1][i]) * input_frag[1][j]; + } + } + } + + // reuse smem + float *smemoutput = reinterpret_cast(smem); + // float *smembias = reinterpret_cast(smem + 16 * 1024); + + // bias ldg/sts + // if (tx < 128) + // { + // smembias[tx] = param.bias[by * 128 + tx]; + // } + + uint32_t output_sts_addr = warp_id * 512 + mma_tid_y * 4 * 8 * 4 + mma_tid_x * 4; + uint32_t output_lds_addr = warp_id * 512 + lane_id; + // uint32_t bias_lds_addr = warp_id / 2 * 32; + + uint32_t m_idx = blockIdx.y * 128 + warp_id / 2 * 32; + uint32_t n_idx = blockIdx.x * 128 + warp_id % 2 * 64 + lane_id; + +#pragma unroll + for (int i = 0; i < 2; ++i) + { +#pragma unroll + for (int j = 0; j < 2; ++j) + { + __syncthreads(); + +#pragma unroll + for (int subi = 0; subi < 4; ++subi) + { +#pragma unroll + for (int subj = 0; subj < 4; ++subj) + { + // output sts + smemoutput[output_sts_addr + subi * 8 * 4 + subj] = output_frag[i * 4 + subi][j * 4 + subj]; + } + } + __syncthreads(); + +#pragma unroll + for (int subk = 0; subk < 16; ++subk) + { + int outOffset = z * param.k * param.Oh * param.Ow + (m_idx + i * 16 + subk) * param.Oh * param.Ow + n_idx + j * 32; + if ((m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow) + // output[outOffset] = smemoutput[output_lds_addr + subk * 32] + smembias[bias_lds_addr + i * 16 + subk]; + output[outOffset] = smemoutput[output_lds_addr + subk * 32]; + } + } + } + +} + +template +static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) { + const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE; + conv2d_implicit_kernel<<>>(X_D, K_D, Y_D, P); +} + +static void conv2d_implicit_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params P, cudaStream_t st) { + conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); +} + +static void conv2d_implicit_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const conv_params P, cudaStream_t st) { + conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); +} + +void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * kernel = dst->src[0]; + const ggml_tensor * input = dst->src[1]; + float * K_D = (float *) kernel->data; + const float * X_D = (const float *) input->data; + float * Y_D = (float *) dst->data; + + GGML_ASSERT(ggml_is_contiguous(kernel)); + GGML_ASSERT(kernel->type == GGML_TYPE_F16 || kernel->type == GGML_TYPE_F32); + + // same number of input channels + GGML_ASSERT(input->ne[2] == kernel->ne[2]); + + cudaStream_t st = ctx.stream(); + + const int32_t * p = (const int32_t *) dst->op_params; + const int ST_X = p[0]; // stride_x + const int ST_Y = p[1]; // stride_y + const int PD_X = p[2]; // padding_x + const int PD_Y = p[3]; // padding_y + const int DL_X = p[4]; // dilation_x + const int DL_Y = p[5]; // dilation_y + + // No cwhn + GGML_ASSERT(p[6] == false); + + const int IW = input->ne[0]; // input_w + const int IH = input->ne[1]; // input_h + const int OW = dst->ne[0]; // output_w + const int OH = dst->ne[1]; // output_h + 
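    // Layout note (ggml convention: ne[0] is the fastest-varying dimension): the
    // input is [W, H, C, N] and the kernel [KW, KH, IC, OC] in memory, so the
    // implicit-GEMM kernel can treat them as contiguous N x C x H x W and
    // OC x (IC*KH*KW) arrays; this is why the kernel tensor is asserted to be
    // contiguous above.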
const int KW = kernel->ne[0]; // kernel_w + const int KH = kernel->ne[1]; // kernel_h + const int IC = input->ne[2]; // input_channels + const int OC = kernel->ne[3]; // ouptut_chanles + const int B = input->ne[3]; // n_batches + + const int64_t total = B * OC * OH * OW; + conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total }; + + if (kernel->type == GGML_TYPE_F16) { + conv2d_implicit_cuda_f16(X_D, (half *) K_D, Y_D, params, st); + } else { + conv2d_implicit_cuda_f32(X_D, K_D, Y_D, params, st); + } +} diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh new file mode 100644 index 0000000000..46161feb3c --- /dev/null +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -0,0 +1,5 @@ +#pragma once +#include "common.cuh" + +#define CUDA_CONV2D_IMPLICT_BLOCK_SIZE 256 +void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index d76ea58f78..4e0fd672bd 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -4482,6 +4482,45 @@ struct ggml_tensor * ggml_conv_2d_direct( return result; } + +// ggml_conv_2d_implicitgemm + +struct ggml_tensor * ggml_conv_2d_implicitgemm( + struct ggml_context * ctx, + struct ggml_tensor * a, // convolution kernel [KW, KH, IC, OC] + struct ggml_tensor * b, // input data [W, H, C, N] + int s0, // stride dimension 0 + int s1, // stride dimension 1 + int p0, // padding dimension 0 + int p1, // padding dimension 1 + int d0, // dilation dimension 0 + int d1) {// dilation dimension 1 + + GGML_ASSERT(a->ne[2] == b->ne[2]); + //GGML_ASSERT(a->type == b->type); + + int64_t ne[4]; + ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0); + ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1); + ne[2] = a->ne[3]; + ne[3] = b->ne[3]; + + struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne); + + ggml_set_op_params_i32(result, 0, s0); + ggml_set_op_params_i32(result, 1, s1); + ggml_set_op_params_i32(result, 2, p0); + ggml_set_op_params_i32(result, 3, p1); + ggml_set_op_params_i32(result, 4, d0); + ggml_set_op_params_i32(result, 5, d1); + + result->op = GGML_OP_CONV_2D_IMPLICIT; + result->src[0] = a; + result->src[1] = b; + + return result; +} + // ggml_conv_3d struct ggml_tensor * ggml_conv_3d( From 4d772873b94641386a48f923cead6aca618e0d8e Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 3 Sep 2025 11:29:14 -0400 Subject: [PATCH 002/122] Add implicit convolution support for 2D tensors in CPU and CUDA implementations --- ggml/src/ggml-cpu/ggml-cpu.c | 6 ++ ggml/src/ggml-cuda/conv2d-implicit.cu | 118 +++++++++----------------- ggml/src/ggml-cuda/ggml-cuda.cu | 5 ++ ggml/src/ggml.c | 4 +- 4 files changed, 53 insertions(+), 80 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 0d5d3a3440..16d9f0204a 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -1880,6 +1880,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_conv_2d(params, tensor); } break; + case GGML_OP_CONV_2D_IMPLICIT: + { + ggml_compute_forward_conv_2d(params, tensor); + } break; case GGML_OP_CONV_3D: { ggml_compute_forward_conv_3d(params, tensor); @@ -2256,6 +2260,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_IM2COL: case GGML_OP_IM2COL_BACK: case GGML_OP_CONV_2D: + case GGML_OP_CONV_2D_IMPLICIT: case GGML_OP_CONV_3D: case GGML_OP_CONV_2D_DW: case 
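/* Note on the hunk above: in this patch series the CPU backend handles
   GGML_OP_CONV_2D_IMPLICIT by reusing ggml_compute_forward_conv_2d, i.e. the
   existing direct conv path; only the CUDA backend gets a dedicated
   implicit-GEMM kernel. */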
GGML_OP_CONV_TRANSPOSE_1D: @@ -2778,6 +2783,7 @@ struct ggml_cplan ggml_graph_plan( } } break; case GGML_OP_CONV_2D: + case GGML_OP_CONV_2D_IMPLICIT: case GGML_OP_CONV_3D: { cur = GGML_IM2COL_WORK_SIZE; diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index d1b1dc7d3c..72f8d30baf 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -1,81 +1,33 @@ #include "conv2d-implicit.cuh" #include "convert.cuh" -struct conv_params { - const int64_t IW, IH; - const int64_t OW, OH; - const int64_t KW, KH; - const int64_t ST_X, ST_Y; - const int64_t PD_X, PD_Y; - const int64_t DL_X, DL_Y; - const int64_t IC, OC; - const int64_t B; - const int64_t TOTAL; -}; +typedef struct{ + unsigned int n; //batch szie + unsigned int c; //channel number + unsigned int h; //height + unsigned int w; //width + unsigned int k; //number of filters + unsigned int r; //filter height + unsigned int s; //filter width + unsigned int u; //stride height + unsigned int v; //stride width + unsigned int p; //padding height + unsigned int q; //padding width + unsigned int d_h; //dilation height + unsigned int d_w; //dilation width + unsigned int Oh; //output height + unsigned int Ow; //output width +} param_t; -struct kernel_bounds { - int64_t y_min, y_max; - int64_t x_min, x_max; -}; -__device__ __forceinline__ int64_t max64(int64_t a, int64_t b) { - return (a > b) ? a : b; -} -__device__ __forceinline__ int64_t min64(int64_t a, int64_t b) { - return (a < b) ? a : b; -} - -__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int64_t out_x, int64_t out_y, const conv_params & P) { - kernel_bounds bounds; - bounds.y_min = max64(0, (P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y); - bounds.y_max = min64(P.KH, (P.IH + P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y); - bounds.x_min = max64(0, (P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X); - bounds.x_max = min64(P.KW, (P.IW + P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X); - return bounds; -} - -__device__ __forceinline__ int calculate_input_coord(int64_t out_coord, - int64_t kern_coord, - int64_t stride, - int64_t dilation, - int64_t padding) { - return out_coord * stride + kern_coord * dilation - padding; -} - -struct whcn_layout { - __device__ static int64_t input_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) { - return n * (P.IC * P.IW * P.IH) + c * P.IW * P.IH + y * P.IW + x; - } - - __device__ static int64_t kernel_index(int64_t c_out, int64_t c_in, int64_t ky, int64_t kx, const conv_params & P) { - return c_out * (P.IC * P.KH * P.KW) + c_in * (P.KH * P.KW) + ky * P.KW + kx; - } - - __device__ static int64_t output_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) { - return n * (P.OC * P.OW * P.OH) + c * P.OW * P.OH + y * P.OW + x; - } - - __device__ static void unpack_indices(int64_t global_idx, - const conv_params & P, - int64_t & n, - int64_t & c, - int64_t & out_y, - int64_t & out_x) { - out_x = global_idx % P.OW; - out_y = (global_idx / P.OW) % P.OH; - c = (global_idx / (P.OW * P.OH)) % P.OC; - n = global_idx / (P.OW * P.OH * P.OC); - } -}; - -template +template static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, const T * __restrict__ kernel, float * __restrict__ output, - const conv_params P) { + const param_t ¶m) { - __shared__ __align__(16 * 1024) char smem[24 * 1024]; + extern __shared__ __align__(16 * 1024) char smem[]; T *smemweight = reinterpret_cast(smem); float *smeminput = 
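    // Shared-memory layout (sketch; sizes assume the 24 KiB dynamic allocation
    // requested by the launcher later in this hunk): the first 16 KiB are reserved
    // for the double-buffered weight tile (2 x 8 x 132 elements of T) and the
    // region starting at 16 KiB holds the double-buffered input tile
    // (2 x 8 x 128 floats, 8 KiB). The 132 pitch appears to add 4 elements of
    // padding per row, presumably to limit shared-memory bank conflicts.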
reinterpret_cast(smem + 16 * 1024); @@ -151,8 +103,8 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, #pragma unroll for (int i = 0; i < 4; ++i) { - int curH = posh_ori[i] + curR; // input h - int curW = posw_ori[i] + curS; // input w + int curH = posh_ori[i] + curR * param.d_h; // input h + int curW = posw_ori[i] + curS * param.d_w; // input w int inOffsetTmp = curC * inChannelOffset + curH * param.w + curW; if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h) { @@ -210,8 +162,8 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, #pragma unroll for (int i = 0; i < 4; ++i) { - int curH = posh_ori[i] + curR; // input h - int curW = posw_ori[i] + curS; // input w + int curH = posh_ori[i] + curR * param.d_h; // input h + int curW = posw_ori[i] + curS * param.d_w; // input w int inOffsetTmp = curC * inChannelOffset + curH * param.w + curW; if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h) { @@ -334,16 +286,25 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, } template -static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) { - const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE; - conv2d_implicit_kernel<<>>(X_D, K_D, Y_D, P); +static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const param_t &P, cudaStream_t st) { + // const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE; + int blockx = ((P.Oh * P.Ow + 127) / 128); // blockx number + int blocky = (P.k + 127) / 128; // blocky number + int blockz = P.n; // blockz number + int threadx = CUDA_CONV2D_IMPLICT_BLOCK_SIZE; // threadx number per block + int thready = 1; // thready number per block + int threadz = 1; // threadz number per block + dim3 thblock(threadx, thready, threadz); + dim3 grid(blockx, blocky, blockz); + int smem_size = 24 * 1024; + conv2d_implicit_kernel<<>>(X_D, K_D, Y_D, P); } -static void conv2d_implicit_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params P, cudaStream_t st) { +static void conv2d_implicit_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const param_t &P, cudaStream_t st) { conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); } -static void conv2d_implicit_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const conv_params P, cudaStream_t st) { +static void conv2d_implicit_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const param_t &P, cudaStream_t st) { conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); } @@ -384,7 +345,8 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * const int B = input->ne[3]; // n_batches const int64_t total = B * OC * OH * OW; - conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total }; + // param_t params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total }; + param_t params = { B, IC, IH, IW, OC, KH, KW, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, OH, OW }; if (kernel->type == GGML_TYPE_F16) { conv2d_implicit_cuda_f16(X_D, (half *) K_D, Y_D, params, st); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index e06f95f081..0b799fbaf1 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -13,6 +13,7 @@ #include "ggml-cuda/concat.cuh" #include "ggml-cuda/conv-transpose-1d.cuh" #include "ggml-cuda/conv2d.cuh" +#include 
"ggml-cuda/conv2d-implicit.cuh" #include "ggml-cuda/conv2d-dw.cuh" #include "ggml-cuda/conv2d-transpose.cuh" #include "ggml-cuda/convert.cuh" @@ -2455,6 +2456,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_CONV_2D: ggml_cuda_op_conv2d(ctx, dst); break; + case GGML_OP_CONV_2D_IMPLICIT: + ggml_cuda_op_conv2d_implicit(ctx, dst); + break; case GGML_OP_CONV_2D_DW: ggml_cuda_op_conv2d_dw(ctx, dst); break; @@ -3560,6 +3564,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g } case GGML_OP_IM2COL: case GGML_OP_CONV_2D: + case GGML_OP_CONV_2D_IMPLICIT: case GGML_OP_CONV_2D_DW: case GGML_OP_CONV_TRANSPOSE_2D: case GGML_OP_POOL_2D: diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 4e0fd672bd..69003dfc5c 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1018,7 +1018,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GLU", }; -static_assert(GGML_OP_COUNT == 89, "GGML_OP_COUNT != 89"); +static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1121,7 +1121,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "glu(x)", }; -static_assert(GGML_OP_COUNT == 89, "GGML_OP_COUNT != 89"); +static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); From 3877608dc05e86e824ada455b0cf36f759c04192 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 3 Sep 2025 12:45:19 -0400 Subject: [PATCH 003/122] fix passing param as reference --- ggml/src/ggml-cuda/conv2d-implicit.cu | 25 ++++--- ggml/src/ggml.c | 2 + tests/test-backend-ops.cpp | 99 +++++++++++++++++++++++++++ 3 files changed, 118 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 72f8d30baf..a78720ecc6 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -25,9 +25,9 @@ template static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, const T * __restrict__ kernel, float * __restrict__ output, - const param_t ¶m) { + const param_t param) { - extern __shared__ __align__(16 * 1024) char smem[]; + extern __shared__ __align__(16 * 1024) char smem[]; T *smemweight = reinterpret_cast(smem); float *smeminput = reinterpret_cast(smem + 16 * 1024); @@ -35,6 +35,12 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, int bx = blockIdx.x; int by = blockIdx.y; + // if(tx == 0 && bx == 0 && by == 0 && blockIdx.z == 0){ + // printf("param.n=%d, param.c=%d, param.h=%d, param.w=%d, param.k=%d, param.r=%d, param.s=%d, param.u=%d, param.v=%d, param.p=%d, param.q=%d, param.d_h=%d, param.d_w=%d, param.Oh=%d, param.Ow=%d\n",param.n,param.c,param.h,param.w,param.k,param.r,param.s,param.u,param.v,param.p,param.q,param.d_h,param.d_w,param.Oh,param.Ow); + // // printf("param.n=%d\n",param.n); + // } + // __syncthreads(); + // Warp tile const int lane_id = threadIdx.x % 32; const int warp_id = threadIdx.x / 32; @@ -85,6 +91,10 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, } } // ldg + // if(tx == 0 && bx == 0 && by == 0 && blockIdx.z == 0){ + // printf("param.n=%d, param.c=%d, param.h=%d, param.w=%d, param.k=%d, param.r=%d, param.s=%d, param.u=%d, param.v=%d, param.p=%d, param.q=%d, param.d_h=%d, param.d_w=%d, param.Oh=%d, 
param.Ow=%d\n",param.n,param.c,param.h,param.w,param.k,param.r,param.s,param.u,param.v,param.p,param.q,param.d_h,param.d_w,param.Oh,param.Ow); + // } + // __syncthreads(); #pragma unroll for (int i = 0; i < 4; ++i) { @@ -282,11 +292,10 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, } } } - } template -static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const param_t &P, cudaStream_t st) { +static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const param_t P, cudaStream_t st) { // const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE; int blockx = ((P.Oh * P.Ow + 127) / 128); // blockx number int blocky = (P.k + 127) / 128; // blocky number @@ -300,11 +309,11 @@ static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, conv2d_implicit_kernel<<>>(X_D, K_D, Y_D, P); } -static void conv2d_implicit_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const param_t &P, cudaStream_t st) { +static void conv2d_implicit_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const param_t P, cudaStream_t st) { conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); } -static void conv2d_implicit_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const param_t &P, cudaStream_t st) { +static void conv2d_implicit_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const param_t P, cudaStream_t st) { conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); } @@ -343,9 +352,9 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * const int IC = input->ne[2]; // input_channels const int OC = kernel->ne[3]; // ouptut_chanles const int B = input->ne[3]; // n_batches - + const int64_t total = B * OC * OH * OW; - // param_t params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total }; + param_t params = { B, IC, IH, IW, OC, KH, KW, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, OH, OW }; if (kernel->type == GGML_TYPE_F16) { diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 69003dfc5c..cdf13a1370 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -975,6 +975,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "IM2COL", "IM2COL_BACK", "CONV_2D", + "CONV_2D_IMPLICIT", "CONV_3D", "CONV_2D_DW", "CONV_TRANSPOSE_2D", @@ -1078,6 +1079,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "im2col(x)", "im2col_back(x)", "conv_2d(x)", + "conv_2d_implicit(x)", "conv_3d(x)", "conv_2d_dw(x)", "conv_transpose_2d(x)", diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 3a58621094..9ab73434fe 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4116,6 +4116,94 @@ struct test_conv_2d : public test_case { } }; +// CONV_2D_IMPLICIT +struct test_conv_2d_implicit : public test_case { + const std::array ne_input; + const std::array ne_kernel; + const ggml_type type_kernel; + const int stride0; + const int stride1; + const int padding0; + const int padding1; + const int dilation0; + const int dilation1; + // Whether the inputs are contiguous in the channel dim or the width dim + const bool cwhn; + + + + std::string vars() override { + return VARS_TO_STR10(ne_input, ne_kernel, type_kernel, stride0, stride1, padding0, padding1, dilation0, dilation1, cwhn); + } + + double max_nmse_err() override { + return 5e-4; + } + + uint64_t op_flops(ggml_tensor * t) override { + GGML_UNUSED(t); + // Just counting matmul costs: + // KxCRS @ CRSxNPQ = KxNPQ --> KxNPQx(CRS+CRS-1) flops + + // Copied 
from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) + auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t { + return (ins + 2 * p - d * (ks - 1) - 1) / s + 1; + }; + + int64_t W = ne_input[0]; + int64_t H = ne_input[1]; + int64_t KW = ne_kernel[0]; + int64_t KH = ne_kernel[1]; + int64_t Cin = ne_kernel[2]; + int64_t Cout = ne_kernel[3]; + int64_t N = ne_input[3]; + int64_t OH = calc_conv_output_size(H, KH, stride0, padding0, dilation0); + int64_t OW = calc_conv_output_size(W, KW, stride0, padding0, dilation0); + + int64_t K = Cout; + int64_t CRS = Cin * KH * KW; + int64_t NPQ = N * OH * OW; + + return K * NPQ * (2 * CRS - 1); + } + + test_conv_2d_implicit(std::array ne_input = { 64, 64, 16, 1 }, + std::array ne_kernel = { 3, 3, 1, 16 }, ggml_type type_kernel = GGML_TYPE_F32, int stride0 = 1, + int stride1 = 1, int padding0 = 0, int padding1 = 0, int dilation0 = 1, int dilation1 = 1, bool cwhn = false) : + ne_input(ne_input), + ne_kernel(ne_kernel), + type_kernel(type_kernel), + stride0(stride0), + stride1(stride1), + padding0(padding0), + padding1(padding1), + dilation0(dilation0), + dilation1(dilation1), + cwhn(cwhn) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data()); + ggml_set_name(input, "input"); + + ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data()); + ggml_set_name(kernel, "kernel"); + + if (cwhn) { + // change memory layout to channel-most-contiguous (CWHN), + // then permute it back so NE matches the original input + input = ggml_cont(ctx, ggml_permute(ctx, input, 1, 2, 0, 3)); + input = ggml_permute(ctx, input, 2, 0, 1, 3); + kernel = ggml_cont(ctx, ggml_permute(ctx, kernel, 2, 3, 1, 0)); + kernel = ggml_permute(ctx, kernel, 3, 2, 0, 1); + } + + ggml_tensor * out = + ggml_conv_2d_implicitgemm(ctx, kernel, input, stride0, stride1, padding0, padding1, dilation0, dilation1); + ggml_set_name(out, "out"); + return out; + } +}; + // GGML_OP_CONV_2D_DW struct test_conv_2d_dw : public test_case { const std::array ne_input; @@ -6454,6 +6542,17 @@ static std::vector> make_test_cases_perf() { } } + for (auto kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) { + for (auto act_case : cases) { + // Direct CONV_2D + test_cases.emplace_back(new test_conv_2d_implicit( + { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] }, + { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] }, + kernel_type, 1, 1, 0, 0, 1, 1, false)); + } + } + + test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1})); test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1})); From 6d84cbb5abc2f7f3590c9ec3c5b01496543ec593 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 3 Sep 2025 15:45:09 -0400 Subject: [PATCH 004/122] Fix parameter order in conv2d_implicit and add comprehensive test cases for 2D convolution --- ggml/src/ggml-cuda/conv2d-implicit.cu | 2 +- tests/test-backend-ops.cpp | 24 ++++++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index a78720ecc6..4f452ab98b 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -355,7 +355,7 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * const int64_t total = B * OC * 
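    // Note on the fix below (grounded in the param_t field comments, where u/p/d_h
    // are the height-wise stride/padding/dilation): param_t expects height before
    // width, while ggml's op_params store the x (width) value first, so ST_Y/ST_X,
    // PD_Y/PD_X and DL_Y/DL_X have to be swapped when the struct is filled in.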
OH * OW; - param_t params = { B, IC, IH, IW, OC, KH, KW, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, OH, OW }; + param_t params = { B, IC, IH, IW, OC, KH, KW, ST_Y, ST_X, PD_Y, PD_X, DL_Y, DL_X, OH, OW }; if (kernel->type == GGML_TYPE_F16) { conv2d_implicit_cuda_f16(X_D, (half *) K_D, Y_D, params, st); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 9ab73434fe..d5e1005d2f 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -5790,6 +5790,30 @@ static std::vector> make_test_cases_eval() { } } + for (uint32_t s0 : { 1, 3 }) { + for (uint32_t p1 : { 2, 5 }) { + for (uint32_t Cin : { 1, 25 }) { + for (uint32_t Cout : { 1, 12 }) { + for (uint32_t KH : { 1, 2, 3, 11 }) { + for (uint32_t KW : { 1, 2, 3, 11 }) { + for (uint32_t H : { 1, 133 }) { + for (uint32_t W : { 1, 141 }) { + if (calc_conv_output_size(W, KW, s0, p0, d0) > 0 && + calc_conv_output_size(H, KH, s1, p1, d1) > 0) { + for (auto kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) { + test_cases.emplace_back(new test_conv_2d_implicit( + { W, H, Cin, 2 }, { KW, KH, Cin, Cout }, kernel_type, s0, s1, p0, p1, d0, d1, false)); + } + } + } + } + } + } + } + } + } + } + // sycl backend will limit task global_range < MAX_INT // test cases for 2D im2col with large input W and H (occurs in stable-diffusion) // however these cases need to alloc more memory which may fail in some devices (Intel Arc770, etc.) From 5ffe97be9c35169aea1e451426eb53e5430f7d24 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Thu, 4 Sep 2025 15:32:29 -0400 Subject: [PATCH 005/122] Fix boundary check in conv2d_implicit_kernel to include channel limits --- ggml/src/ggml-cuda/conv2d-implicit.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 4f452ab98b..d9fabd9657 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -116,7 +116,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, int curH = posh_ori[i] + curR * param.d_h; // input h int curW = posw_ori[i] + curS * param.d_w; // input w int inOffsetTmp = curC * inChannelOffset + curH * param.w + curW; - if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h) + if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curC < param.c) { input_ldg_reg[i] = input[inOffset + inOffsetTmp]; } @@ -175,7 +175,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, int curH = posh_ori[i] + curR * param.d_h; // input h int curW = posw_ori[i] + curS * param.d_w; // input w int inOffsetTmp = curC * inChannelOffset + curH * param.w + curW; - if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h) + if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curC < param.c) { input_ldg_reg[i] = input[inOffset + inOffsetTmp]; } From 4b0f9d571f4166035ee72558e4710b6205893af7 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Fri, 5 Sep 2025 08:29:57 -0400 Subject: [PATCH 006/122] Refactor conv2d_implicit_kernel for improved readability and consistency; update parameter comments and remove unused code --- ggml/src/ggml-cuda/conv2d-implicit.cu | 143 ++++++++------------------ 1 file changed, 44 insertions(+), 99 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index d9fabd9657..31205187c1 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -2,8 +2,8 @@ #include "convert.cuh" typedef 
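/* Background for the `curC < param.c` guards added above in PATCH 005 (an
   inference from the loop structure, not stated in the commit message): the
   prologue derives a channel from warp_id and the main loop always prefetches
   one 8-wide slice ahead (crs + 8 + warp_id), so the derived channel index can
   run past the last channel, e.g. in the final iteration or when c*r*s < 8;
   the extra check zero-fills those lanes instead of reading out of bounds. */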
struct{ - unsigned int n; //batch szie - unsigned int c; //channel number + unsigned int n; //batch size + unsigned int c; //number if channels unsigned int h; //height unsigned int w; //width unsigned int k; //number of filters @@ -23,23 +23,18 @@ typedef struct{ template static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, - const T * __restrict__ kernel, - float * __restrict__ output, - const param_t param) { + const T * __restrict__ kernel, + float * __restrict__ output, + const param_t param) { - extern __shared__ __align__(16 * 1024) char smem[]; + extern __shared__ unsigned char smem[]; T *smemweight = reinterpret_cast(smem); float *smeminput = reinterpret_cast(smem + 16 * 1024); int tx = threadIdx.x; int bx = blockIdx.x; int by = blockIdx.y; - - // if(tx == 0 && bx == 0 && by == 0 && blockIdx.z == 0){ - // printf("param.n=%d, param.c=%d, param.h=%d, param.w=%d, param.k=%d, param.r=%d, param.s=%d, param.u=%d, param.v=%d, param.p=%d, param.q=%d, param.d_h=%d, param.d_w=%d, param.Oh=%d, param.Ow=%d\n",param.n,param.c,param.h,param.w,param.k,param.r,param.s,param.u,param.v,param.p,param.q,param.d_h,param.d_w,param.Oh,param.Ow); - // // printf("param.n=%d\n",param.n); - // } - // __syncthreads(); + // Warp tile const int lane_id = threadIdx.x % 32; @@ -60,8 +55,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, int posh_ori[4]; int posw_ori[4]; #pragma unroll - for (int i = 0; i < 4; ++i) - { + for (int i = 0; i < 4; ++i){ posh_ori[i] = ((bx * 128 + tx % 32 + i * 32) / param.Ow) * param.u - param.p; posw_ori[i] = ((bx * 128 + tx % 32 + i * 32) % param.Ow) * param.v - param.q; } @@ -82,28 +76,19 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, float input_frag[2][8]; float output_frag[8][8]; #pragma unroll - for (int i = 0; i < 8; ++i) - { + for (int i = 0; i < 8; ++i){ #pragma unroll - for (int j = 0; j < 8; ++j) - { + for (int j = 0; j < 8; ++j){ output_frag[i][j] = 0; } } // ldg - // if(tx == 0 && bx == 0 && by == 0 && blockIdx.z == 0){ - // printf("param.n=%d, param.c=%d, param.h=%d, param.w=%d, param.k=%d, param.r=%d, param.s=%d, param.u=%d, param.v=%d, param.p=%d, param.q=%d, param.d_h=%d, param.d_w=%d, param.Oh=%d, param.Ow=%d\n",param.n,param.c,param.h,param.w,param.k,param.r,param.s,param.u,param.v,param.p,param.q,param.d_h,param.d_w,param.Oh,param.Ow); - // } - // __syncthreads(); #pragma unroll - for (int i = 0; i < 4; ++i) - { - if (tx % 8 < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k) - { + for (int i = 0; i < 4; ++i){ + if (tx % 8 < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k){ weight_ldg_reg[i] = kernel[weiOffset + tx % 8 + i * weightKOffset]; } - else - { + else{ weight_ldg_reg[i] = (T)0.f; } } @@ -111,57 +96,46 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, int curR = ((tx / 32) % (param.r * param.s)) / param.s; // kernel r offset int curS = ((tx / 32) % (param.r * param.s)) % param.s; // kernel s offset #pragma unroll - for (int i = 0; i < 4; ++i) - { + for (int i = 0; i < 4; ++i){ int curH = posh_ori[i] + curR * param.d_h; // input h int curW = posw_ori[i] + curS * param.d_w; // input w int inOffsetTmp = curC * inChannelOffset + curH * param.w + curW; - if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curC < param.c) - { + if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curC < param.c){ input_ldg_reg[i] = input[inOffset + inOffsetTmp]; } - else - { + else{ input_ldg_reg[i] = 0.0; } } // 
sts - for (int i = 0; i < 4; ++i) - { + for (int i = 0; i < 4; ++i){ smemweight[weight_sts_addr + i] = weight_ldg_reg[i]; } - for (int i = 0; i < 4; ++i) - { + for (int i = 0; i < 4; ++i){ smeminput[input_sts_addr + i * 32] = input_ldg_reg[i]; } __syncthreads(); // lds #pragma unroll - for (int i = 0; i < 4; ++i) - { + for (int i = 0; i < 4; ++i){ weight_frag[0][i] = smemweight[weight_lds_addr + i]; weight_frag[0][i + 4] = smemweight[weight_lds_addr + i + 16]; } #pragma unroll - for (int i = 0; i < 4; ++i) - { + for (int i = 0; i < 4; ++i){ input_frag[0][i] = smeminput[input_lds_addr + i]; input_frag[0][i + 4] = smeminput[input_lds_addr + i + 32]; } - for (int crs = 0; crs < param.r * param.s * param.c; crs += 8) - { + for (int crs = 0; crs < param.r * param.s * param.c; crs += 8){ // ldg int weiOffsetTmp = crs + 8 + tx % 8; #pragma unroll - for (int i = 0; i < 4; ++i) - { - if (weiOffsetTmp < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k) - { + for (int i = 0; i < 4; ++i){ + if (weiOffsetTmp < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k){ weight_ldg_reg[i] = kernel[weiOffset + weiOffsetTmp + i * weightKOffset]; } - else - { + else{ weight_ldg_reg[i] = (T)0.f; } } @@ -170,76 +144,62 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, curS = ((crs + 8 + tx / 32) % (param.r * param.s)) % param.s; // kernel s offset #pragma unroll - for (int i = 0; i < 4; ++i) - { + for (int i = 0; i < 4; ++i){ int curH = posh_ori[i] + curR * param.d_h; // input h int curW = posw_ori[i] + curS * param.d_w; // input w int inOffsetTmp = curC * inChannelOffset + curH * param.w + curW; - if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curC < param.c) - { + if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curC < param.c){ input_ldg_reg[i] = input[inOffset + inOffsetTmp]; } - else - { + else{ input_ldg_reg[i] = 0.f; } } int load_flag = write_flag ^ 1; #pragma unroll - for (int subcrs = 0; subcrs < 8 - 1; ++subcrs) - { + for (int subcrs = 0; subcrs < 8 - 1; ++subcrs){ #pragma unroll - for (int i = 0; i < 4; ++i) - { + for (int i = 0; i < 4; ++i){ weight_frag[(subcrs + 1) % 2][i] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i]; weight_frag[(subcrs + 1) % 2][i + 4] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i + 16]; } #pragma unroll - for (int i = 0; i < 4; ++i) - { + for (int i = 0; i < 4; ++i){ input_frag[(subcrs + 1) % 2][i] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i]; input_frag[(subcrs + 1) % 2][i + 4] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i + 32]; } #pragma unroll - for (int i = 0; i < 8; ++i) - { + for (int i = 0; i < 8; ++i){ #pragma unroll - for (int j = 0; j < 8; ++j) - { + for (int j = 0; j < 8; ++j){ output_frag[i][j] += ggml_cuda_cast(weight_frag[subcrs % 2][i]) * input_frag[subcrs % 2][j]; } } } // sts - for (int i = 0; i < 4; ++i) - { + for (int i = 0; i < 4; ++i){ smemweight[write_flag * 132 * 8 + weight_sts_addr + i] = weight_ldg_reg[i]; } - for (int i = 0; i < 4; ++i) - { + for (int i = 0; i < 4; ++i){ smeminput[write_flag * 128 * 8 + input_sts_addr + i * 32] = input_ldg_reg[i]; } __syncthreads(); write_flag ^= 1; #pragma unroll - for (int i = 0; i < 4; ++i) - { + for (int i = 0; i < 4; ++i){ weight_frag[0][i] = smemweight[(load_flag ^ 1) * 132 * 8 + weight_lds_addr + i]; weight_frag[0][i + 4] = smemweight[(load_flag ^ 1) * 132 * 8 + weight_lds_addr + i + 16]; } #pragma unroll - for (int i = 0; i 
< 4; ++i) - { + for (int i = 0; i < 4; ++i){ input_frag[0][i] = smeminput[(load_flag ^ 1) * 128 * 8 + input_lds_addr + i]; input_frag[0][i + 4] = smeminput[(load_flag ^ 1) * 128 * 8 + input_lds_addr + i + 32]; } #pragma unroll - for (int i = 0; i < 8; ++i) - { + for (int i = 0; i < 8; ++i){ #pragma unroll - for (int j = 0; j < 8; ++j) - { + for (int j = 0; j < 8; ++j){ output_frag[i][j] += ggml_cuda_cast(weight_frag[1][i]) * input_frag[1][j]; } } @@ -247,35 +207,23 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, // reuse smem float *smemoutput = reinterpret_cast(smem); - // float *smembias = reinterpret_cast(smem + 16 * 1024); - // bias ldg/sts - // if (tx < 128) - // { - // smembias[tx] = param.bias[by * 128 + tx]; - // } uint32_t output_sts_addr = warp_id * 512 + mma_tid_y * 4 * 8 * 4 + mma_tid_x * 4; uint32_t output_lds_addr = warp_id * 512 + lane_id; - // uint32_t bias_lds_addr = warp_id / 2 * 32; uint32_t m_idx = blockIdx.y * 128 + warp_id / 2 * 32; uint32_t n_idx = blockIdx.x * 128 + warp_id % 2 * 64 + lane_id; #pragma unroll - for (int i = 0; i < 2; ++i) - { + for (int i = 0; i < 2; ++i){ #pragma unroll - for (int j = 0; j < 2; ++j) - { + for (int j = 0; j < 2; ++j){ __syncthreads(); - #pragma unroll - for (int subi = 0; subi < 4; ++subi) - { + for (int subi = 0; subi < 4; ++subi){ #pragma unroll - for (int subj = 0; subj < 4; ++subj) - { + for (int subj = 0; subj < 4; ++subj){ // output sts smemoutput[output_sts_addr + subi * 8 * 4 + subj] = output_frag[i * 4 + subi][j * 4 + subj]; } @@ -283,11 +231,9 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, __syncthreads(); #pragma unroll - for (int subk = 0; subk < 16; ++subk) - { + for (int subk = 0; subk < 16; ++subk){ int outOffset = z * param.k * param.Oh * param.Ow + (m_idx + i * 16 + subk) * param.Oh * param.Ow + n_idx + j * 32; if ((m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow) - // output[outOffset] = smemoutput[output_lds_addr + subk * 32] + smembias[bias_lds_addr + i * 16 + subk]; output[outOffset] = smemoutput[output_lds_addr + subk * 32]; } } @@ -295,8 +241,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, } template -static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const param_t P, cudaStream_t st) { - // const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE; +static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const param_t P, cudaStream_t st) { int blockx = ((P.Oh * P.Ow + 127) / 128); // blockx number int blocky = (P.k + 127) / 128; // blocky number int blockz = P.n; // blockz number From 83a3b7d6a98727705ed04f29afdd587ea3c17c37 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sat, 6 Sep 2025 17:26:19 -0400 Subject: [PATCH 007/122] Refactor conv2d_implicit_kernel for improved bitwise operations; add test for implicit convolution --- ggml/src/ggml-cuda/conv2d-implicit.cu | 72 +++-- tests/CMakeLists.txt | 1 + tests/test-conv2d-implicit.cpp | 390 ++++++++++++++++++++++++++ 3 files changed, 434 insertions(+), 29 deletions(-) create mode 100644 tests/test-conv2d-implicit.cpp diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 31205187c1..1e2540f8ca 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -37,16 +37,16 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, // Warp tile - const int 
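        // Equivalences relied on by the bitwise forms introduced below (valid for
        // the non-negative thread indices used here):
        //   x % 32 == (x & 31)    x / 32 == (x >> 5)
        //   x % 8  == (x & 7)     x / 8  == (x >> 3)
        // so the & / >> versions are drop-in replacements for the modulo and
        // division they replace.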
lane_id = threadIdx.x % 32; - const int warp_id = threadIdx.x / 32; - const int mma_tid_x = (lane_id / 2) % 8; - const int mma_tid_y = (lane_id / 16) * 2 + (lane_id % 2); + const int lane_id = threadIdx.x & 31; + const int warp_id = threadIdx.x >> 5; + const int mma_tid_x = (lane_id >> 1) % 8; + const int mma_tid_y = (lane_id >> 4) * 2 + (lane_id & 1); // lds addr - int weight_lds_addr = (warp_id / 2) * 32 + mma_tid_y * 4; - int input_lds_addr = (warp_id % 2) * 64 + mma_tid_x * 4; + int weight_lds_addr = (warp_id >> 1) * 32 + mma_tid_y * 4; + int input_lds_addr = (warp_id & 1) * 64 + mma_tid_x * 4; - int x = bx * 128 + input_lds_addr; - int y = by * 128 + weight_lds_addr; + // int x = bx * 128 + input_lds_addr; + // int y = by * 128 + weight_lds_addr; int z = blockIdx.z; T weight_ldg_reg[4]; @@ -56,20 +56,20 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, int posw_ori[4]; #pragma unroll for (int i = 0; i < 4; ++i){ - posh_ori[i] = ((bx * 128 + tx % 32 + i * 32) / param.Ow) * param.u - param.p; - posw_ori[i] = ((bx * 128 + tx % 32 + i * 32) % param.Ow) * param.v - param.q; + posh_ori[i] = ((bx * 128 + lane_id + i * 32) / param.Ow) * param.u - param.p; + posw_ori[i] = ((bx * 128 + lane_id + i * 32) % param.Ow) * param.v - param.q; } int inOffset = z * param.c * param.h * param.w; - int weiOffset = (by * 128 + tx / 8 * 4) * param.c * param.r * param.s; + int weiOffset = (by * 128 + (tx >> 3) * 4) * param.c * param.r * param.s; int inChannelOffset = param.h * param.w; - int weightChannelOffset = param.r * param.s; + // int weightChannelOffset = param.r * param.s; int weightKOffset = param.c * param.r * param.s; // sts addr - int weight_sts_addr = (tx % 8) * 132 + - (tx / 8) * 4; - int input_sts_addr = (tx / 32) * 128 + (tx % 32); + int weight_sts_addr = (tx & 7) * 132 + + (tx >> 3) * 4; + int input_sts_addr = (warp_id) * 128 + (lane_id); int write_flag = 1; T weight_frag[2][8]; @@ -85,16 +85,16 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, // ldg #pragma unroll for (int i = 0; i < 4; ++i){ - if (tx % 8 < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k){ - weight_ldg_reg[i] = kernel[weiOffset + tx % 8 + i * weightKOffset]; + if (tx % 8 < weightKOffset && by * 128 + (tx >> 3) * 4 + i < param.k){ + weight_ldg_reg[i] = kernel[weiOffset + (tx & 7) + i * weightKOffset]; } else{ weight_ldg_reg[i] = (T)0.f; } } - int curC = (tx / 32) / (param.r * param.s); // channel offset - int curR = ((tx / 32) % (param.r * param.s)) / param.s; // kernel r offset - int curS = ((tx / 32) % (param.r * param.s)) % param.s; // kernel s offset + int curC = (warp_id) / (param.r * param.s); // channel offset + int curR = ((warp_id) % (param.r * param.s)) / param.s; // kernel r offset + int curS = ((warp_id) % (param.r * param.s)) % param.s; // kernel s offset #pragma unroll for (int i = 0; i < 4; ++i){ int curH = posh_ori[i] + curR * param.d_h; // input h @@ -127,21 +127,23 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, input_frag[0][i] = smeminput[input_lds_addr + i]; input_frag[0][i + 4] = smeminput[input_lds_addr + i + 32]; } + + // main loop for (int crs = 0; crs < param.r * param.s * param.c; crs += 8){ // ldg - int weiOffsetTmp = crs + 8 + tx % 8; + int weiOffsetTmp = crs + 8 + (tx & 7); #pragma unroll for (int i = 0; i < 4; ++i){ - if (weiOffsetTmp < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k){ + if (weiOffsetTmp < weightKOffset && by * 128 + (tx >> 3) * 4 + i < param.k){ weight_ldg_reg[i] = 
kernel[weiOffset + weiOffsetTmp + i * weightKOffset]; } else{ weight_ldg_reg[i] = (T)0.f; } } - curC = (crs + 8 + tx / 32) / (param.r * param.s); // channel offset - curR = ((crs + 8 + tx / 32) % (param.r * param.s)) / param.s; // kernel r offset - curS = ((crs + 8 + tx / 32) % (param.r * param.s)) % param.s; // kernel s offset + curC = (crs + 8 + warp_id) / (param.r * param.s); // channel offset + curR = ((crs + 8 + warp_id) % (param.r * param.s)) / param.s; // kernel r offset + curS = ((crs + 8 + warp_id) % (param.r * param.s)) % param.s; // kernel s offset #pragma unroll for (int i = 0; i < 4; ++i){ @@ -160,13 +162,25 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, for (int subcrs = 0; subcrs < 8 - 1; ++subcrs){ #pragma unroll for (int i = 0; i < 4; ++i){ - weight_frag[(subcrs + 1) % 2][i] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i]; - weight_frag[(subcrs + 1) % 2][i + 4] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i + 16]; + weight_frag[(subcrs + 1) & 1][i] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i]; + weight_frag[(subcrs + 1) & 1][i + 4] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i + 16]; } + // // compute base pointer once + // T* base_ptr = smemweight + load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132; + + // // first 4 values -> weight_frag[...][0..3] + // float4 v0 = *reinterpret_cast(base_ptr); + + // // next 4 values (offset +16) -> weight_frag[...][4..7] + // float4 v1 = *reinterpret_cast(base_ptr + 16); + + // // unpack into weight_frag + // *reinterpret_cast(&weight_frag[(subcrs + 1) % 2][0]) = v0; + // *reinterpret_cast(&weight_frag[(subcrs + 1) % 2][4]) = v1; #pragma unroll for (int i = 0; i < 4; ++i){ - input_frag[(subcrs + 1) % 2][i] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i]; - input_frag[(subcrs + 1) % 2][i + 4] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i + 32]; + input_frag[(subcrs + 1) & 1][i] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i]; + input_frag[(subcrs + 1) & 1][i + 4] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i + 32]; } #pragma unroll diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 9171957756..7ce76f0105 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -198,6 +198,7 @@ if (NOT LLAMA_SANITIZE_ADDRESS) endif() llama_build_and_test(test-gguf.cpp) llama_build_and_test(test-backend-ops.cpp) +llama_build_and_test(test-conv2d-implicit.cpp) llama_build_and_test(test-model-load-cancel.cpp LABEL "model") llama_build_and_test(test-autorelease.cpp LABEL "model") diff --git a/tests/test-conv2d-implicit.cpp b/tests/test-conv2d-implicit.cpp new file mode 100644 index 0000000000..b0efba2f1c --- /dev/null +++ b/tests/test-conv2d-implicit.cpp @@ -0,0 +1,390 @@ +#include "ggml.h" +#include "ggml-alloc.h" +#include "ggml-cpu.h" +#include "ggml-backend.h" + +#ifdef GGML_USE_CUDA +#include "ggml-cuda.h" +//#include +#endif + +#ifdef GGML_USE_METAL +#include "ggml-metal.h" +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +static void ggml_log_callback_default(ggml_log_level level, const char * text, void * user_data) { + (void) level; + (void) user_data; + fputs(text, stderr); + fflush(stderr); +} + +struct test_model { + struct ggml_tensor * a; + struct ggml_tensor * b; + ggml_backend_t backend = NULL; + 
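    // a: convolution kernel, laid out as [KW, KH, IC, OC]
    // b: input activations, laid out as [W, H, C, N]
    // buffer: backend buffer that backs both tensors once load_model() has run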
ggml_backend_buffer_t buffer; + struct ggml_context * ctx; +}; + + + +void load_model(test_model & model, int ic, int oc, int iw, int ih, bool use_gpu = false ) { + // create data + int KW = 3, KH = 3, IC = ic, OC = oc; + int IW = iw, IH = ih, N = 1; + + // printf(" input: IC = %d, OC = %d, IW = %d, IH = %d \n ", IC, OC, IW, IH); + + // Initialize adata + std::vector adata(KW * KH * IC * OC); + for (int i = 0; i < KW * KH * IC * OC; i++) { + adata[i] = 2.5f; + } + + // Convert adata to fp16 format + // std::vector hadata(KW * KH * IC * OC); + // ggml_fp32_to_fp16_row(adata.data(), hadata.data(), KW * KH * IC * OC); + + // Initialize bdata + std::vector bdata(IW * IH * IC * N); + for (int i = 0; i < IW * IH * IC * N; i++) { + bdata[i] = 1.5f; + } + + size_t buffer_size = 0; + { + buffer_size += KW * KH * IC * OC * ggml_type_size(GGML_TYPE_F32); // tensor a + buffer_size += IW * IH * IC * N * ggml_type_size(GGML_TYPE_F32); // tensor b + buffer_size += 1024; // overhead + } + + // printf("%s: ggml tensor size = %d bytes\n", __func__, (int) sizeof(ggml_tensor)); + // printf("%s: backend buffer size = %0.2f MB\n", __func__, (buffer_size/ 1024.f/ 1024.f)); + + int num_tensors = 2; + struct ggml_init_params params { + /*.mem_size =*/ ggml_tensor_overhead() * num_tensors, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + // initialize the backend +#ifdef GGML_USE_CUDA + if (use_gpu) { + // fprintf(stderr, "%s: using CUDA backend\n", __func__); + model.backend = ggml_backend_cuda_init(0); + if (!model.backend) { + fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); + } + } +#endif + +#ifdef GGML_USE_METAL + if (use_gpu) { + fprintf(stderr, "%s: using Metal backend\n", __func__); + ggml_backend_metal_log_set_callback(ggml_log_callback_default, nullptr); + model.backend = ggml_backend_metal_init(); + if (!model.backend) { + fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__); + } + } +#endif + + if(!model.backend) { + // fallback to CPU backend + model.backend = ggml_backend_cpu_init(); + } + + model.buffer = ggml_backend_alloc_buffer(model.backend, buffer_size); + + // create context + model.ctx = ggml_init(params); + + // create tensors + model.a = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F32, KW, KH, IC, OC); + model.b = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F32, IW, IH, IC, N); + + // create a allocator + struct ggml_tallocr alloc = ggml_tallocr_new(model.buffer); + + // alloc memory + ggml_tallocr_alloc(&alloc, model.a); + + // load data to buffer + if(ggml_backend_is_cpu(model.backend)) { + memcpy(model.a->data, adata.data(), ggml_nbytes(model.a)); + } else { + ggml_backend_tensor_set(model.a, adata.data(), 0, ggml_nbytes(model.a)); + } + + // alloc memory + ggml_tallocr_alloc(&alloc, model.b); + + if(ggml_backend_is_cpu(model.backend) +#ifdef GGML_USE_METAL + || ggml_backend_is_metal(model.backend) +#endif + ) { + memcpy(model.b->data, bdata.data(), ggml_nbytes(model.b)); + } else { + ggml_backend_tensor_set(model.b, bdata.data(), 0, ggml_nbytes(model.b)); + } +} + +typedef struct ggml_cgraph* (*build_graph_t)(const test_model& model); + +struct ggml_cgraph * build_graph_0(const test_model& model) { + static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(); + static std::vector buf(buf_size); + + struct ggml_init_params params0 = { + /*.mem_size =*/ buf_size, + /*.mem_buffer =*/ buf.data(), + /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph() + }; + + // create a temporally 
context to build the graph + struct ggml_context * ctx0 = ggml_init(params0); + + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + int s0 = 1; + int s1 = 1; + int p0 = 1; + int p1 = 1; + int d0 = 1; + int d1 = 1; + + + + // recalculate for avoid fragmentation + struct ggml_tensor* conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); + ggml_set_name(conv2d_res, "conv2d_res"); + ggml_build_forward_expand(gf, conv2d_res); + // int64_t *ne = conv2d_res->ne; + // printf("conv2d: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); + + + // struct ggml_tensor* wino_res = ggml_conv_2d_3x3(ctx0, model.a, model.b); + // ggml_set_name(wino_res, "wino_res"); + // ggml_build_forward_expand(gf, wino_res); + // ne = wino_res->ne; + // printf("wino: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); + ggml_free(ctx0); + return gf; +} + +struct ggml_cgraph * build_graph_1(const test_model& model) { + static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(); + static std::vector buf(buf_size); + + struct ggml_init_params params0 = { + /*.mem_size =*/ buf_size, + /*.mem_buffer =*/ buf.data(), + /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph() + }; + + // create a temporally context to build the graph + struct ggml_context * ctx0 = ggml_init(params0); + + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + int s0 = 1; + int s1 = 1; + int p0 = 1; + int p1 = 1; + int d0 = 1; + int d1 = 1; + + + + // recalculate for avoid fragmentation + // struct ggml_tensor* conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); + // ggml_set_name(conv2d_res, "conv2d_res"); + // ggml_build_forward_expand(gf, conv2d_res); + // int64_t *ne = conv2d_res->ne; + // printf("conv2d: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); + + + struct ggml_tensor* wino_res = ggml_conv_2d_implicitgemm(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); + ggml_set_name(wino_res, "wino_res"); + ggml_build_forward_expand(gf, wino_res); + // ne = wino_res->ne; + // printf("wino: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); + ggml_free(ctx0); + return gf; +} + + + + +std::vector compute_graph(const test_model & model, ggml_gallocr_t allocr, + build_graph_t build_graph, int iters, double *t) { + struct ggml_cgraph * gf = build_graph(model); + + + // allocate tensors + ggml_gallocr_alloc_graph(allocr, gf); + int n_threads = 1; + + if (ggml_backend_is_cpu(model.backend)) { + ggml_backend_cpu_set_n_threads(model.backend, n_threads); + } + +#ifdef GGML_USE_METAL + if (ggml_backend_is_metal(model.backend)) { + ggml_backend_metal_set_n_cb(model.backend, n_threads); + } +#endif + + + + ggml_backend_graph_compute(model.backend, gf); + + ggml_backend_synchronize(model.backend); + + int64_t start_time = ggml_time_us(); + + for(int iter=0; iter data(ggml_nelements(res)); + ggml_backend_tensor_get(res, data.data(), 0, ggml_nbytes(res)); + + *t = time_us/1000; + return data; + +} + + +int main(void) +{ + ggml_time_init(); + std::vector> configs = { + std::make_tuple(64,64,48,64), + std::make_tuple(320,320,104,152), + std::make_tuple(640,640,52,76), + std::make_tuple(640,640,104,152), + std::make_tuple(960,320,104,152), + std::make_tuple(1280,1280,26,38), + std::make_tuple(1280,640,52,76), + std::make_tuple(1920,1280,26,38), + std::make_tuple(2560,1280,26,38), + std::make_tuple(512,512,104,152), + std::make_tuple(512,512,208,304), + std::make_tuple(512,256,416,608), + std::make_tuple(256,128,832,1216), + 
std::make_tuple(256,256,832,1216), + std::make_tuple(320,256,1024,1920) + }; + + int k = 0; + + for (auto c : configs){ + test_model model; + load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c), std::get<3>(c), true); + + ggml_gallocr_t allocr = NULL; + allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend)); + + //create the worst case graph for memory usage estimation + struct ggml_cgraph * gf = build_graph_0(model); + + // compute the required memory + ggml_gallocr_reserve(allocr, gf); + size_t mem_size0 = ggml_gallocr_get_buffer_size(allocr, 0); + // fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f); + + + struct ggml_cgraph * gf_res_0 = NULL; + int iterations = 20; + + double run_time0; + std::vector conv2d_data = compute_graph(model, allocr, build_graph_0, iterations, &run_time0); + + ggml_gallocr_free(allocr); + + allocr = NULL; + + allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend)); + + //create the worst case graph for memory usage estimation + gf = build_graph_1(model); + + // compute the required memory + ggml_gallocr_reserve(allocr, gf); + size_t mem_size1 = ggml_gallocr_get_buffer_size(allocr, 0); + // fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f); + + + struct ggml_cgraph * gf_res_1 = NULL; + + double run_time1; + std::vector wino_data = compute_graph(model, allocr, build_graph_1, iterations, &run_time1); + + if(k==0) { + k = 1; + fprintf(stderr, "| (IC, OC, IW, IH) | im2col+GEMM TIME | im2col+GEMM VRAM | implicit GEMM TIME | implicit GEMM VRAM \n"); + fprintf(stderr, "| --- | --- | --- | --- | --- \n"); + } + + fprintf(stderr, " | (%d, %d, %d, %d) | %.2f ms | %.2f MB | %.2f ms | %.2f MB\n", + std::get<0>(c), std::get<1>(c), std::get<2>(c), std::get<3>(c), + run_time0, mem_size0/1024.0f/1024.0f, + run_time1, mem_size1/1024.0f/1024.0f); + + + // for(int i = 0; i < ggml_nelements(wino_res); i++) { + // for(int i = 0; i < 3*28; i++) { + // float diff = fabs(conv2d_data[i] - wino_data[i]); + // // if(diff > 1.e-4) { + // printf("(%f, %f, %f, %d) \n", + // conv2d_data[i], + // wino_data[i], diff, i); + // // break; + // // } + // } + + ggml_free(model.ctx); + ggml_backend_buffer_free(model.buffer); + ggml_backend_free(model.backend); + ggml_gallocr_free(allocr); + + } + + + // printf("\nPerforming test:\n"); + + return 0; +} From 53a2ccbe129472e66a05cd87eee2ed6b3d42a73a Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 24 Sep 2025 21:48:20 -0400 Subject: [PATCH 008/122] minor update and add direct conv in benchmarking --- ggml/src/ggml-cuda/conv2d-implicit.cu | 3 +- tests/test-conv2d-implicit.cpp | 87 ++++++++++++++++++++++++--- 2 files changed, 79 insertions(+), 11 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 1e2540f8ca..cae35280c0 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -185,9 +185,10 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, #pragma unroll for (int i = 0; i < 8; ++i){ + auto weight_frag_i = ggml_cuda_cast(weight_frag[subcrs % 2][i]); #pragma unroll for (int j = 0; j < 8; ++j){ - output_frag[i][j] += ggml_cuda_cast(weight_frag[subcrs % 2][i]) * input_frag[subcrs % 2][j]; + output_frag[i][j] += weight_frag_i * input_frag[subcrs % 2][j]; } } } diff --git a/tests/test-conv2d-implicit.cpp b/tests/test-conv2d-implicit.cpp index b0efba2f1c..6077299cb4 100644 --- 
a/tests/test-conv2d-implicit.cpp +++ b/tests/test-conv2d-implicit.cpp @@ -52,8 +52,8 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, bool use_gpu } // Convert adata to fp16 format - // std::vector hadata(KW * KH * IC * OC); - // ggml_fp32_to_fp16_row(adata.data(), hadata.data(), KW * KH * IC * OC); + std::vector hadata(KW * KH * IC * OC); + ggml_fp32_to_fp16_row(adata.data(), hadata.data(), KW * KH * IC * OC); // Initialize bdata std::vector bdata(IW * IH * IC * N); @@ -63,7 +63,8 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, bool use_gpu size_t buffer_size = 0; { - buffer_size += KW * KH * IC * OC * ggml_type_size(GGML_TYPE_F32); // tensor a + // buffer_size += KW * KH * IC * OC * ggml_type_size(GGML_TYPE_F32); // tensor a + buffer_size += KW * KH * IC * OC * ggml_type_size(GGML_TYPE_F16); // tensor a buffer_size += IW * IH * IC * N * ggml_type_size(GGML_TYPE_F32); // tensor b buffer_size += 1024; // overhead } @@ -111,7 +112,7 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, bool use_gpu model.ctx = ggml_init(params); // create tensors - model.a = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F32, KW, KH, IC, OC); + model.a = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F16, KW, KH, IC, OC); model.b = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F32, IW, IH, IC, N); // create a allocator @@ -122,9 +123,9 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, bool use_gpu // load data to buffer if(ggml_backend_is_cpu(model.backend)) { - memcpy(model.a->data, adata.data(), ggml_nbytes(model.a)); + memcpy(model.a->data, hadata.data(), ggml_nbytes(model.a)); } else { - ggml_backend_tensor_set(model.a, adata.data(), 0, ggml_nbytes(model.a)); + ggml_backend_tensor_set(model.a, hadata.data(), 0, ggml_nbytes(model.a)); } // alloc memory @@ -208,6 +209,48 @@ struct ggml_cgraph * build_graph_1(const test_model& model) { + // recalculate for avoid fragmentation + // struct ggml_tensor* conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); + // ggml_set_name(conv2d_res, "conv2d_res"); + // ggml_build_forward_expand(gf, conv2d_res); + // int64_t *ne = conv2d_res->ne; + // printf("conv2d: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); + + + // struct ggml_tensor* wino_res = ggml_conv_2d_implicitgemm(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); + struct ggml_tensor* wino_res = ggml_conv_2d_direct(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); + ggml_set_name(wino_res, "wino_res"); + ggml_build_forward_expand(gf, wino_res); + // ne = wino_res->ne; + // printf("wino: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); + ggml_free(ctx0); + return gf; +} + +struct ggml_cgraph * build_graph_2(const test_model& model) { + static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(); + static std::vector buf(buf_size); + + struct ggml_init_params params0 = { + /*.mem_size =*/ buf_size, + /*.mem_buffer =*/ buf.data(), + /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph() + }; + + // create a temporally context to build the graph + struct ggml_context * ctx0 = ggml_init(params0); + + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + int s0 = 1; + int s1 = 1; + int p0 = 1; + int p1 = 1; + int d0 = 1; + int d1 = 1; + + + // recalculate for avoid fragmentation // struct ggml_tensor* conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); // ggml_set_name(conv2d_res, "conv2d_res"); @@ -217,6 +260,7 @@ struct 
ggml_cgraph * build_graph_1(const test_model& model) { struct ggml_tensor* wino_res = ggml_conv_2d_implicitgemm(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); + // struct ggml_tensor* wino_res = ggml_conv_2d_direct(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); ggml_set_name(wino_res, "wino_res"); ggml_build_forward_expand(gf, wino_res); // ne = wino_res->ne; @@ -353,16 +397,39 @@ int main(void) double run_time1; std::vector wino_data = compute_graph(model, allocr, build_graph_1, iterations, &run_time1); + + ggml_gallocr_free(allocr); + + allocr = NULL; + + allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend)); + + //create the worst case graph for memory usage estimation + gf = build_graph_2(model); + + // compute the required memory + ggml_gallocr_reserve(allocr, gf); + size_t mem_size2 = ggml_gallocr_get_buffer_size(allocr, 0); + // fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f); + + + struct ggml_cgraph * gf_res_2 = NULL; + + double run_time2; + wino_data = compute_graph(model, allocr, build_graph_2, iterations, &run_time2); + + if(k==0) { k = 1; - fprintf(stderr, "| (IC, OC, IW, IH) | im2col+GEMM TIME | im2col+GEMM VRAM | implicit GEMM TIME | implicit GEMM VRAM \n"); - fprintf(stderr, "| --- | --- | --- | --- | --- \n"); + fprintf(stderr, "| (IC, OC, IW, IH) | im2col+GEMM TIME | im2col+GEMM VRAM | direct TIME | direct VRAM | implicit GEMM TIME | implicit GEMM VRAM \n"); + fprintf(stderr, "| --- | --- | --- | --- | --- | --- | --- \n"); } - fprintf(stderr, " | (%d, %d, %d, %d) | %.2f ms | %.2f MB | %.2f ms | %.2f MB\n", + fprintf(stderr, " | (%d, %d, %d, %d) | %.2f ms | %.2f MB | %.2f ms | %.2f MB | %.2f ms | %.2f MB\n", std::get<0>(c), std::get<1>(c), std::get<2>(c), std::get<3>(c), run_time0, mem_size0/1024.0f/1024.0f, - run_time1, mem_size1/1024.0f/1024.0f); + run_time1, mem_size1/1024.0f/1024.0f, + run_time2, mem_size2/1024.0f/1024.0f); // for(int i = 0; i < ggml_nelements(wino_res); i++) { From c6255442bb56c3123ceec0eefd6a13262ee1bc10 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 8 Oct 2025 13:38:16 -0400 Subject: [PATCH 009/122] minor updates --- tests/test-backend-ops.cpp | 69 ++++++++++++++++++++++++++++++---- tests/test-conv2d-implicit.cpp | 13 ++++--- 2 files changed, 69 insertions(+), 13 deletions(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 1b9e8a2464..3c4388f8a5 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -39,6 +39,7 @@ #include #include #include +#include static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) { size_t nels = ggml_nelements(tensor); @@ -6725,14 +6726,66 @@ static std::vector> make_test_cases_perf() { } } - for (auto kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) { - for (auto act_case : cases) { - // Direct CONV_2D - test_cases.emplace_back(new test_conv_2d_implicit( - { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] }, - { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] }, - kernel_type, 1, 1, 0, 0, 1, 1, false)); - } + // for (auto kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) { + // for (auto act_case : cases) { + // // Direct CONV_2D + // test_cases.emplace_back(new test_conv_2d_implicit( + // { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] }, + // { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] }, + // kernel_type, 1, 1, 0, 0, 1, 1, false)); + // } + // } + + // 
Stable-diffusion layers + std::map idx_sd{ + { "iw", 0 }, + { "ih", 1 }, + { "kw", 2 }, + { "kh", 3 }, + { "Cout", 4 }, + { "Cin", 5 }, + { "B", 6 }, + }; + + // Input image size + uint32_t w = 768; + uint32_t h = 1024; + + // Number of filters (base) + uint32_t Cout_b = 128; + uint32_t Cin_b = 128; + + std::vector> cases_sd = { + { w / 8, h / 8, 3, 3, Cout_b * 4, Cin_b * 4, 1 }, // x10 (called 10 times) + { w / 4, h / 4, 3, 3, Cout_b * 4, Cin_b * 4, 1 }, // x7 + { w / 2, h / 2, 3, 3, Cout_b * 2, Cin_b * 2, 1 }, // x5 + { w, h, 3, 3, Cout_b, Cin_b, 1 }, // x5 + { w / 8, h / 8, 1, 1, Cout_b * 4, Cin_b * 4, 1 }, // x4 + { w / 8, h / 8, 1, 1, 4, 4, 1 }, + { w / 8, h / 8, 3, 3, Cout_b * 4, 4, 1 }, + + { w / 2, h / 2, 3, 3, Cout_b * 4, Cin_b * 4, 1 }, + { w / 2, h / 2, 3, 3, Cout_b * 2, Cin_b * 4, 1 }, + { w / 2, h / 2, 1, 1, Cout_b * 2, Cin_b * 4, 1 }, + + { w, h, 3, 3, Cout_b, Cin_b * 2, 1 }, + { w, h, 1, 1, Cout_b, Cin_b * 2, 1 }, + { w, h, 3, 3, Cout_b * 2, Cin_b * 2, 1 }, + + { w, h, 3, 3, 3, Cin_b, 1 }, + }; + + for (auto act_case : cases_sd) { + GGML_ASSERT(act_case[idx_sd["kw"]] == 3 || act_case[idx_sd["kw"]] == 1); + GGML_ASSERT(act_case[idx_sd["kh"]] == 3 || act_case[idx_sd["kh"]] == 1); + + uint32_t p0 = act_case[idx_sd["kw"]] == 3 ? 1 : 0; + uint32_t p1 = act_case[idx_sd["kh"]] == 3 ? 1 : 0; + + test_cases.emplace_back(new test_conv_2d_implicit( + { act_case[idx_sd["iw"]], act_case[idx_sd["ih"]], act_case[idx_sd["Cin"]], act_case[idx_sd["B"]] }, + { act_case[idx_sd["kw"]], act_case[idx_sd["kh"]], act_case[idx_sd["Cin"]], act_case[idx_sd["Cout"]] }, + GGML_TYPE_F16, 1, 1, p0, p1, 1, 1, false)); } diff --git a/tests/test-conv2d-implicit.cpp b/tests/test-conv2d-implicit.cpp index 6077299cb4..e963e2b361 100644 --- a/tests/test-conv2d-implicit.cpp +++ b/tests/test-conv2d-implicit.cpp @@ -63,8 +63,8 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, bool use_gpu size_t buffer_size = 0; { - // buffer_size += KW * KH * IC * OC * ggml_type_size(GGML_TYPE_F32); // tensor a - buffer_size += KW * KH * IC * OC * ggml_type_size(GGML_TYPE_F16); // tensor a + buffer_size += KW * KH * IC * OC * ggml_type_size(GGML_TYPE_F32); // tensor a + // buffer_size += KW * KH * IC * OC * ggml_type_size(GGML_TYPE_F16); // tensor a buffer_size += IW * IH * IC * N * ggml_type_size(GGML_TYPE_F32); // tensor b buffer_size += 1024; // overhead } @@ -112,7 +112,8 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, bool use_gpu model.ctx = ggml_init(params); // create tensors - model.a = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F16, KW, KH, IC, OC); + // model.a = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F16, KW, KH, IC, OC); + model.a = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F32, KW, KH, IC, OC); model.b = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F32, IW, IH, IC, N); // create a allocator @@ -123,9 +124,11 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, bool use_gpu // load data to buffer if(ggml_backend_is_cpu(model.backend)) { - memcpy(model.a->data, hadata.data(), ggml_nbytes(model.a)); + // memcpy(model.a->data, hadata.data(), ggml_nbytes(model.a)); + memcpy(model.a->data, adata.data(), ggml_nbytes(model.a)); } else { - ggml_backend_tensor_set(model.a, hadata.data(), 0, ggml_nbytes(model.a)); + // ggml_backend_tensor_set(model.a, hadata.data(), 0, ggml_nbytes(model.a)); + ggml_backend_tensor_set(model.a, adata.data(), 0, ggml_nbytes(model.a)); } // alloc memory From 0ca43582e853014a79c71843b86144fe245353a3 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 
8 Oct 2025 13:52:56 -0400 Subject: [PATCH 010/122] reorder register tile loop --- ggml/src/ggml-cuda/conv2d-implicit.cu | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index cae35280c0..f2af27a7fb 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -183,12 +183,20 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, input_frag[(subcrs + 1) & 1][i + 4] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i + 32]; } +// #pragma unroll +// for (int i = 0; i < 8; ++i){ +// auto weight_frag_i = ggml_cuda_cast(weight_frag[subcrs % 2][i]); +// #pragma unroll +// for (int j = 0; j < 8; ++j){ +// output_frag[i][j] += weight_frag_i * input_frag[subcrs % 2][j]; +// } +// } #pragma unroll - for (int i = 0; i < 8; ++i){ - auto weight_frag_i = ggml_cuda_cast(weight_frag[subcrs % 2][i]); + for (int j = 0; j < 8; ++j){ + // auto weight_frag_i = ggml_cuda_cast(weight_frag[subcrs % 2][i]); #pragma unroll - for (int j = 0; j < 8; ++j){ - output_frag[i][j] += weight_frag_i * input_frag[subcrs % 2][j]; + for (int i = 0; i < 8; ++i){ + output_frag[j][i] += ggml_cuda_cast(weight_frag[subcrs % 2][i]) * input_frag[subcrs % 2][j]; } } } @@ -215,7 +223,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, for (int i = 0; i < 8; ++i){ #pragma unroll for (int j = 0; j < 8; ++j){ - output_frag[i][j] += ggml_cuda_cast(weight_frag[1][i]) * input_frag[1][j]; + output_frag[i][j] += ggml_cuda_cast(weight_frag[1][j]) * input_frag[1][i]; } } } @@ -240,15 +248,15 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, #pragma unroll for (int subj = 0; subj < 4; ++subj){ // output sts - smemoutput[output_sts_addr + subi * 8 * 4 + subj] = output_frag[i * 4 + subi][j * 4 + subj]; + smemoutput[output_sts_addr + subj * 8 * 4 + subi] = output_frag[i * 4 + subi][j * 4 + subj]; } } __syncthreads(); #pragma unroll for (int subk = 0; subk < 16; ++subk){ - int outOffset = z * param.k * param.Oh * param.Ow + (m_idx + i * 16 + subk) * param.Oh * param.Ow + n_idx + j * 32; - if ((m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow) + int outOffset = z * param.k * param.Oh * param.Ow + (m_idx + j * 16 + subk) * param.Oh * param.Ow + n_idx + i * 32; + if ((m_idx + j * 16 + subk) < param.k && (n_idx + i * 32) < param.Oh * param.Ow) output[outOffset] = smemoutput[output_lds_addr + subk * 32]; } } From 16b0f0ae3c76576bf6f325951f0bb0332ce70f06 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Mon, 13 Oct 2025 18:41:30 -0400 Subject: [PATCH 011/122] work in progress --- ggml/src/ggml-cuda/conv2d-implicit.cu | 782 +++++++++++++++++++------ ggml/src/ggml-cuda/conv2d-implicit.cuh | 22 + 2 files changed, 625 insertions(+), 179 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index f2af27a7fb..0a5c370f29 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -1,172 +1,338 @@ #include "conv2d-implicit.cuh" #include "convert.cuh" -typedef struct{ - unsigned int n; //batch size - unsigned int c; //number if channels - unsigned int h; //height - unsigned int w; //width - unsigned int k; //number of filters - unsigned int r; //filter height - unsigned int s; //filter width - unsigned int u; //stride height - unsigned int v; //stride width - unsigned int p; //padding height - 
unsigned int q; //padding width - unsigned int d_h; //dilation height - unsigned int d_w; //dilation width - unsigned int Oh; //output height - unsigned int Ow; //output width -} param_t; + +static const int WARPSIZE = 32; // warpSize is not constexpr + +static __global__ void reduce_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols, const int nrows) { + const int row = blockIdx.x; + const int col = threadIdx.x; + + float sum = 0.0f; + if (row * blockDim.x + col < ncols) { + for (int i = 0; i < nrows; ++i){ + sum += x[i * ncols + row * blockDim.x + col]; + } + dst[row * blockDim.x + col] = sum; + } +} - -template +template static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, const T * __restrict__ kernel, float * __restrict__ output, const param_t param) { - extern __shared__ unsigned char smem[]; + // __shared__ char smem[4 * (TM*TN*NUM_THREADS <= (BM * BK + BK * (BN+PAD)) ? (BM * BK + BK * (BN+PAD)) : (TM*TN*NUM_THREADS))]; + __shared__ char smem[sizeof(float) * (TM*TN*NUM_THREADS) <= sizeof(float) * 2 * BM * BK + sizeof(T)*2*BK * (BN+PAD) ? + sizeof(float)*2*BM*BK + sizeof(T)*2*BK*(BN+PAD) : sizeof(float) * (TM*TN*NUM_THREADS)]; + // __shared__ float smeminput[2 * BM * BK]; + // __shared__ float smemweight[2 * BK * (BN+PAD)]; T *smemweight = reinterpret_cast(smem); - float *smeminput = reinterpret_cast(smem + 16 * 1024); + float *smeminput = reinterpret_cast(smem + 2 * BK * (BN+PAD) * sizeof(T)); + + const uint tx = threadIdx.x; + const uint bx = blockIdx.x; + const uint by = blockIdx.y; + + const uint PQ = param.Oh * param.Ow; - int tx = threadIdx.x; - int bx = blockIdx.x; - int by = blockIdx.y; - - // Warp tile - const int lane_id = threadIdx.x & 31; - const int warp_id = threadIdx.x >> 5; - const int mma_tid_x = (lane_id >> 1) % 8; - const int mma_tid_y = (lane_id >> 4) * 2 + (lane_id & 1); - // lds addr - int weight_lds_addr = (warp_id >> 1) * 32 + mma_tid_y * 4; - int input_lds_addr = (warp_id & 1) * 64 + mma_tid_x * 4; + const uint lane_id = tx % WARPSIZE; + const uint warp_id = tx / WARPSIZE; + const int mma_tid_x = warp_id / (BN / WN); //(lane_id / 2) % 8; + const int mma_tid_y = warp_id % (BN / WN); //(lane_id / 16) * 2 + (lane_id % 2); - // int x = bx * 128 + input_lds_addr; - // int y = by * 128 + weight_lds_addr; + // lds addr + // int weight_lds_addr = (warp_id / 2) * 32 + mma_tid_y * 4; + // int input_lds_addr = (warp_id % 2) * 64 + mma_tid_x * 4; + + // size of the warp subtile + constexpr uint WMITER = (WM * WN) / (WARPSIZE * TM * TN * WNITER); + constexpr uint WSUBM = WM / WMITER; // 64/2=32 + constexpr uint WSUBN = WN / WNITER; // 32/2=16 + + // Placement of the thread in the warp subtile + // const uint threadIdxInWarp = tx % WARPSIZE; // [0, 31] + const uint threadColInWarp = lane_id % (WSUBN / TN); // i%(16/4) + const uint threadRowInWarp = lane_id / (WSUBN / TN); // i/4 + + // int x = bx * BM + input_lds_addr; + // int y = by * BN + weight_lds_addr; int z = blockIdx.z; - T weight_ldg_reg[4]; - float input_ldg_reg[4]; - - int posh_ori[4]; - int posw_ori[4]; -#pragma unroll - for (int i = 0; i < 4; ++i){ - posh_ori[i] = ((bx * 128 + lane_id + i * 32) / param.Ow) * param.u - param.p; - posw_ori[i] = ((bx * 128 + lane_id + i * 32) % param.Ow) * param.v - param.q; - } - int inOffset = z * param.c * param.h * param.w; - int weiOffset = (by * 128 + (tx >> 3) * 4) * param.c * param.r * param.s; - int inChannelOffset = param.h * param.w; + // float weight_ldg_reg[4]; + // float input_ldg_reg[4]; + // 当前线程处理的数据点在oh、ow上的坐标 
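+ // (i.e. the oh/ow coordinates of the output point handled by this thread) + // Indexing summary for the implicit GEMM below: blockIdx.x tiles the output-pixel (GEMM M) dimension; a flattened pixel index is split into n = idx / (Oh*Ow) and npq_res = idx % (Oh*Ow) when ksplit > 0 (otherwise n = blockIdx.z), with oh = npq_res / Ow and ow = npq_res % Ow. + // blockIdx.y tiles the output-channel (GEMM N) dimension k, and the reduction dimension crs = r*s*c is walked in BK-wide steps, split via fastdiv/fastmodulo into filter-row, filter-column and channel parts (curR, curS, curC). + // The corresponding input tap is (h, w) = (oh*u - p + curR, ow*v - q + curS); out-of-range taps are loaded as zero. + // When ksplit > 0, blockIdx.z instead selects a [start_k, end_k) slice of crs and the per-slice partial results are summed afterwards by reduce_f32.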
+ // int posh_ori = ((bx * 128 + tx / 2 ) / param.Ow) * param.u - param.p; + // int posw_ori = ((bx * 128 + tx / 2 ) % param.Ow) * param.v - param.q; + // int posh_ori = fastdiv(bx * BM + tx / 2, param.OW_fastdiv) * param.u - param.p; + // int posw_ori = fastmodulo(bx * BM + tx / 2, param.OW_fastdiv) * param.v - param.q; + + + // int inOffset = (ksplit > 0): z * param.c * param.h * param.w ; + // int weiOffset = (by * BN + tx / 8 * 4) * param.c * param.r * param.s; + int inChannelOffset = param.c * param.w; // int weightChannelOffset = param.r * param.s; int weightKOffset = param.c * param.r * param.s; - // sts addr - int weight_sts_addr = (tx & 7) * 132 + - (tx >> 3) * 4; - int input_sts_addr = (warp_id) * 128 + (lane_id); + // uint ks, start_k; + // if constexpr (ksplit > 0){ + // const uint ks = (weightKOffset + ksplit - 1) / ksplit; + // const uint start_k = z * ks; + // } else { + // const uint ks = weightKOffset; + // const uint start_k = 0; + // } + const uint ks = (ksplit > 0) ? (weightKOffset + ksplit - 1) / ksplit : weightKOffset; + const uint start_k = (ksplit > 0)? z * ks: 0; + const uint end_k = min(start_k + ks, weightKOffset); + + // sts addr + // int weight_sts_addr = (tx % 8) * 132 + + // (tx / 8) * 4; int write_flag = 1; - T weight_frag[2][8]; - float input_frag[2][8]; - float output_frag[8][8]; -#pragma unroll - for (int i = 0; i < 8; ++i){ -#pragma unroll - for (int j = 0; j < 8; ++j){ - output_frag[i][j] = 0; - } - } + T weight_frag[2][WNITER * TN]; + float input_frag[2][WMITER * TM] = {0.f}; + float output_frag[WMITER * TM * WNITER * TN] = {0.f}; +// #pragma unroll +// for (int i = 0; i < 8; ++i) +// { +// #pragma unroll +// for (int j = 0; j < 8; ++j) +// { +// output_frag[i][j] = 0; +// } +// } + + // calculating the indices that this thread will load into SMEM + // we'll load 128bit / 32bit = 4 elements per thread at each step + const uint innerRowA = tx / (BK / 4); + const uint innerColA = tx % (BK / 4); + constexpr uint rowStrideA = (NUM_THREADS * 4) / BK; + // const uint innerRowB = tx / (BN / 4); + // const uint innerColB = tx % (BN / 4); + // constexpr uint rowStrideB = NUM_THREADS / (BN / 4); + // ldg -#pragma unroll - for (int i = 0; i < 4; ++i){ - if (tx % 8 < weightKOffset && by * 128 + (tx >> 3) * 4 + i < param.k){ - weight_ldg_reg[i] = kernel[weiOffset + (tx & 7) + i * weightKOffset]; - } - else{ - weight_ldg_reg[i] = (T)0.f; + const uint weight_sts_addr = innerRowA + innerColA * (BN+PAD) * 4; + for (uint offset = 0; offset + rowStrideA <= BN; offset += rowStrideA) { + if(vec_load){ + // if (by * BN + innerRowA + offset < param.k && start_k + innerColA * 4 < param.c * param.r * param.s){ + if (by * BN + innerRowA + offset < param.k && start_k + innerColA * 4 < end_k){ + if constexpr (std::is_same_v(T, float)){ + float4 tmp = reinterpret_cast(&kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4])[0]; + smemweight[weight_sts_addr + offset + 0] = tmp.x; + smemweight[weight_sts_addr + offset + (BN+PAD)] = tmp.y; + smemweight[weight_sts_addr + offset + 2*(BN+PAD)] = tmp.z; + smemweight[weight_sts_addr + offset + 3*(BN+PAD)] = tmp.w; + }else{ // read 4 halves + // half val[4]; + float2 tmp = reinterpret_cast(&kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4])[0]; + half *val = reinterpret_cast(&tmp); + // val[1] = reinterpret_cast(&tmp.y); + smemweight[weight_sts_addr + offset + 0] = val[0]; + smemweight[weight_sts_addr + offset + (BN+PAD)] = val[1]; + smemweight[weight_sts_addr + offset + 2*(BN+PAD)] 
= val[2]; + smemweight[weight_sts_addr + offset + 3*(BN+PAD)] = val[3]; + } + } else { + #pragma unroll + for (int i = 0; i < 4; ++i){ + smemweight[weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f; + } + } + }else{ + #pragma unroll + for (int i = 0; i < 4; ++i){ + if (by * BN + innerRowA + offset < param.k && start_k + innerColA * 4 + i < end_k){ + // float4 tmp = reinterpret_cast(¶m.weight[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4])[0]; + smemweight[weight_sts_addr + offset + i*(BN+PAD)] = kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4 + i]; + } else { + smemweight[weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f; + } + } } } - int curC = (warp_id) / (param.r * param.s); // channel offset - int curR = ((warp_id) % (param.r * param.s)) / param.s; // kernel r offset - int curS = ((warp_id) % (param.r * param.s)) % param.s; // kernel s offset -#pragma unroll - for (int i = 0; i < 4; ++i){ - int curH = posh_ori[i] + curR * param.d_h; // input h - int curW = posw_ori[i] + curS * param.d_w; // input w - int inOffsetTmp = curC * inChannelOffset + curH * param.w + curW; - if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curC < param.c){ - input_ldg_reg[i] = input[inOffset + inOffsetTmp]; - } - else{ - input_ldg_reg[i] = 0.0; + + + // int curC = (tx / 32) / (param.r * param.s); // channel offset + // int curR = ((tx / 32) % (param.r * param.s)) / param.s; // kernel r offset + // int curS = ((tx / 32) % (param.r * param.s)) % param.s; // kernel s offset + + // int curR = (tx % 2) * 4 / (param.s * param.c); // channel offset + // int curS = ((tx % 2) * 4 % (param.s * param.c)) / param.c; // kernel r offset + // int curC = ((tx % 2) * 4 % (param.s * param.c)) % param.c; // kernel s offset + + const uint input_sts_addr = innerRowA + innerColA * BM * 4; + for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) { + int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQ : z; + const unsigned int npq_res = (bx * BM + innerRowA + offset) % PQ; + const int posh_ori = fastdiv((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.u - param.p; + const int posw_ori = fastmodulo((ksplit > 0) ? 
npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q; + int inOffset = n * param.c * param.h * param.w ; + if(vec_load){ + const uint curR = fastdiv(start_k + innerColA * 4, param.SC_fastdiv); // channel offset + const uint curS = fastdiv(fastmodulo(start_k + innerColA * 4, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const uint curC = fastmodulo(fastmodulo(start_k + innerColA * 4, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + + const int curH = posh_ori + curR; // input h + const int curW = posw_ori + curS; // input w + if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 < end_k){ + int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; + float4 tmp = reinterpret_cast(&input[inOffset + inOffsetTmp])[0]; + smeminput[input_sts_addr + offset + 0] = tmp.x; + smeminput[input_sts_addr + offset + BM] = tmp.y; + smeminput[input_sts_addr + offset + 2*BM] = tmp.z; + smeminput[input_sts_addr + offset + 3*BM] = tmp.w; + } else { + #pragma unroll + for (int i = 0; i < 4; ++i) + smeminput[input_sts_addr + offset + i*BM] = 0.f; + } + } else { + #pragma unroll + for (int i = 0; i < 4; ++i){ + const uint curR = fastdiv(start_k + innerColA * 4 + i, param.SC_fastdiv); // channel offset + const uint curS = fastdiv(fastmodulo(start_k + innerColA * 4 + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const uint curC = fastmodulo(fastmodulo(start_k + innerColA * 4 + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + + const int curH = posh_ori + curR; // input h + const int curW = posw_ori + curS; // input w + if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 + i < end_k){ + int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; + smeminput[input_sts_addr + offset + i*BM] = input[inOffset + inOffsetTmp]; + } else { + smeminput[input_sts_addr + offset + i*BM] = 0.f; + } + } } } + // sts - for (int i = 0; i < 4; ++i){ - smemweight[weight_sts_addr + i] = weight_ldg_reg[i]; - } - for (int i = 0; i < 4; ++i){ - smeminput[input_sts_addr + i * 32] = input_ldg_reg[i]; - } + // for (int i = 0; i < 4; ++i) + // { + // smemweight[weight_sts_addr + i*132] = weight_ldg_reg[i]; + // } + // for (int i = 0; i < 4; ++i) + // { + // smeminput[input_sts_addr + i * 128] = input_ldg_reg[i]; + // } __syncthreads(); + + // if(tx == 0 && bx == 0 && by == 0 && z == 0){ + // for(int i=0; i < 128; ++i) + // printf("%.2f,", smeminput[i]); + // printf("\n"); + // for(int i=128; i < 256; ++i) + // printf("%.2f,", smeminput[i]); + // printf("\n"); + // } + + // if(tx == 0 && bx == 0 && by == 0 && z == 0){ + // printf("%u, %u, %u, %u \n", innerRowA, innerColA, rowStrideA, weight_sts_addr); + // for(int i=0; i < 16; ++i) + // printf("%f,", smemweight[i]); + // printf("\n"); + // for(int i=0; i < 16; ++i) + // printf("%f,", param.weight[i*param.c*param.r*param.s]); + // printf("\n"); + // } + // lds -#pragma unroll - for (int i = 0; i < 4; ++i){ - weight_frag[0][i] = smemweight[weight_lds_addr + i]; - weight_frag[0][i + 4] = smemweight[weight_lds_addr + i + 16]; - } -#pragma unroll - for (int i = 0; i < 4; ++i){ - input_frag[0][i] = smeminput[input_lds_addr + i]; - input_frag[0][i + 4] = smeminput[input_lds_addr + i + 32]; - } + // int input_lds_addr = (warp_id % 2) * 64 + mma_tid_x * 4; + const uint input_lds_addr = mma_tid_x * WM; + for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) + for (uint i = 0; i < TM; ++i) + input_frag[0][wSubRowIdx * TM + i] = 
smeminput[input_lds_addr + wSubRowIdx * WSUBM + + threadRowInWarp * TM + i]; - // main loop - for (int crs = 0; crs < param.r * param.s * param.c; crs += 8){ + // int weight_lds_addr = (warp_id / 2) * 32 + mma_tid_y * 4; + const uint weight_lds_addr = mma_tid_y * WN; + for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) + for (uint i = 0; i < TN; ++i) + weight_frag[0][wSubColIdx * TN + i] = smemweight[weight_lds_addr + wSubColIdx * WSUBN + + threadColInWarp * TN + i]; + +// #pragma unroll +// for (int i = 0; i < 4; ++i) +// { +// weight_frag[0][i] = smemweight[weight_lds_addr + i]; +// weight_frag[0][i + 4] = smemweight[weight_lds_addr + i + 16]; +// } + // if(tx == 0 && bx == 0 && by == 0 && z == 0) + // { + // printf("weight_ldg_reg:%f,%f,%f,%f\n", weight_frag[0][0], weight_frag[0][1], weight_frag[0][2], weight_frag[0][3]); + // printf("weight_ldg_reg:%f,%f,%f,%f\n", weight_frag[0][4], weight_frag[0][5], weight_frag[0][6], weight_frag[0][7]); + // } +// #pragma unroll +// for (int i = 0; i < 4; ++i) +// { +// input_frag[0][i] = smeminput[input_lds_addr + i]; +// input_frag[0][i + 4] = smeminput[input_lds_addr + i + 32]; +// } + + + for (int crs = start_k; crs < end_k; crs += BK) + { // ldg - int weiOffsetTmp = crs + 8 + (tx & 7); -#pragma unroll - for (int i = 0; i < 4; ++i){ - if (weiOffsetTmp < weightKOffset && by * 128 + (tx >> 3) * 4 + i < param.k){ - weight_ldg_reg[i] = kernel[weiOffset + weiOffsetTmp + i * weightKOffset]; - } - else{ - weight_ldg_reg[i] = (T)0.f; - } - } - curC = (crs + 8 + warp_id) / (param.r * param.s); // channel offset - curR = ((crs + 8 + warp_id) % (param.r * param.s)) / param.s; // kernel r offset - curS = ((crs + 8 + warp_id) % (param.r * param.s)) % param.s; // kernel s offset +// if (by * BN + tx / 2 < param.k && tx % 2 * 4 < param.c * param.r * param.s){ +// float4 tmp = reinterpret_cast(¶m.weight[by * BN + tx / 2 * weightKOffset + tx % 2 * 4 + crs + 8])[0]; +// weight_ldg_reg[0] = tmp.x; +// weight_ldg_reg[1] = tmp.y; +// weight_ldg_reg[2] = tmp.z; +// weight_ldg_reg[3] = tmp.w; +// } else { +// #pragma unroll +// for (int i = 0; i < 4; ++i) +// weight_ldg_reg[i] = 0.0; +// } + // curR = (crs + 8 + tx % 2 * 4) / (param.s * param.c); // channel offset + // curS = ((crs + 8 + tx % 2 * 4) % (param.s * param.c)) / param.c; // kernel r offset + // curC = ((crs + 8 + tx % 2 * 4) % (param.s * param.c)) % param.c; // kernel s offset +// curR = fastdiv(crs + 8 + (tx % 2) * 4, param.SC_fastdiv); // channel offset +// curS = fastdiv(fastmodulo(crs + 8 + (tx % 2) * 4, param.SC_fastdiv), param.C_fastdiv); // kernel r offset +// curC = fastmodulo(fastmodulo(crs + 8 + (tx % 2) * 4, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + +// int curH = posh_ori + curR; // input h +// int curW = posw_ori + curS; // input w +// if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h){ +// int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; + +// // float4 tmp = reinterpret_cast(¶m.input[inOffset + inOffsetTmp])[0]; +// // input_ldg_reg[0] = tmp.x; +// // input_ldg_reg[1] = tmp.y; +// // input_ldg_reg[2] = tmp.z; +// // input_ldg_reg[3] = tmp.w; +// reinterpret_cast(&input_ldg_reg[0])[0] = reinterpret_cast(¶m.input[inOffset + inOffsetTmp])[0]; } else { +// #pragma unroll +// for (int i = 0; i < 4; ++i) +// input_ldg_reg[i] = 0.0; +// } -#pragma unroll - for (int i = 0; i < 4; ++i){ - int curH = posh_ori[i] + curR * param.d_h; // input h - int curW = posw_ori[i] + curS * param.d_w; // input w - int inOffsetTmp = curC * inChannelOffset + curH 
* param.w + curW; - if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curC < param.c){ - input_ldg_reg[i] = input[inOffset + inOffsetTmp]; - } - else{ - input_ldg_reg[i] = 0.f; - } - } int load_flag = write_flag ^ 1; #pragma unroll - for (int subcrs = 0; subcrs < 8 - 1; ++subcrs){ + for (int subcrs = 0; subcrs < BK - 1; ++subcrs) + { +// #pragma unroll +// for (int i = 0; i < 4; ++i) +// { +// weight_frag[(subcrs + 1) % 2][i] = smemweight[load_flag * (BN+4) * 8 + weight_lds_addr + (subcrs + 1) * (BN+4) + i]; +// weight_frag[(subcrs + 1) % 2][i + 4] = smemweight[load_flag * (BN+4) * 8 + weight_lds_addr + (subcrs + 1) * (BN+4) + i + 16]; +// } #pragma unroll - for (int i = 0; i < 4; ++i){ - weight_frag[(subcrs + 1) & 1][i] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i]; - weight_frag[(subcrs + 1) & 1][i + 4] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i + 16]; - } - // // compute base pointer once - // T* base_ptr = smemweight + load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132; + for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) +#pragma unroll + for (uint i = 0; i < TN; ++i) + weight_frag[(subcrs + 1) % 2][wSubColIdx * TN + i] = smemweight[load_flag * (BN+PAD) * BK + + (subcrs + 1) * (BN+PAD) + weight_lds_addr + wSubColIdx * WSUBN + threadColInWarp * TN + i]; + // float* base_ptr = smemweight + load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132; // // first 4 values -> weight_frag[...][0..3] // float4 v0 = *reinterpret_cast(base_ptr); @@ -177,92 +343,347 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, // // unpack into weight_frag // *reinterpret_cast(&weight_frag[(subcrs + 1) % 2][0]) = v0; // *reinterpret_cast(&weight_frag[(subcrs + 1) % 2][4]) = v1; -#pragma unroll - for (int i = 0; i < 4; ++i){ - input_frag[(subcrs + 1) & 1][i] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i]; - input_frag[(subcrs + 1) & 1][i + 4] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i + 32]; - } - // #pragma unroll -// for (int i = 0; i < 8; ++i){ -// auto weight_frag_i = ggml_cuda_cast(weight_frag[subcrs % 2][i]); -// #pragma unroll -// for (int j = 0; j < 8; ++j){ -// output_frag[i][j] += weight_frag_i * input_frag[subcrs % 2][j]; -// } +// for (int i = 0; i < 4; ++i) +// { +// input_frag[(subcrs + 1) % 2][i] = smeminput[load_flag * BM * 8 + input_lds_addr + (subcrs + 1) * BM + i]; +// input_frag[(subcrs + 1) % 2][i + 4] = smeminput[load_flag * BM * 8 + input_lds_addr + (subcrs + 1) * BM + i + 32]; // } #pragma unroll - for (int j = 0; j < 8; ++j){ - // auto weight_frag_i = ggml_cuda_cast(weight_frag[subcrs % 2][i]); + for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) #pragma unroll - for (int i = 0; i < 8; ++i){ - output_frag[j][i] += ggml_cuda_cast(weight_frag[subcrs % 2][i]) * input_frag[subcrs % 2][j]; + for (uint i = 0; i < TM; ++i) + input_frag[(subcrs + 1) % 2][wSubRowIdx * TM + i] = smeminput[load_flag * BM * BK + + (subcrs + 1) * BM + input_lds_addr + wSubRowIdx * WSUBM + threadRowInWarp * TM + i]; + +// #pragma unroll +// for (int i = 0; i < 8; ++i) +// { +// #pragma unroll +// for (int j = 0; j < 8; ++j) +// { +// output_frag[i][j] += weight_frag[subcrs % 2][i] * input_frag[subcrs % 2][j]; +// } +// } + // execute warptile matmul +#pragma unroll + for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) { +#pragma unroll + for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) { + // 
calculate per-thread results +#pragma unroll + for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) { +#pragma unroll + for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) { + output_frag[(wSubRowIdx * TM + resIdxM) * (WNITER * TN) + + (wSubColIdx * TN) + resIdxN] += + input_frag[subcrs % 2][wSubRowIdx * TM + resIdxM] * + ggml_cuda_cast(weight_frag[subcrs % 2][wSubColIdx * TN + resIdxN]); + // if(tx == 0 && bx == 0 && by == 0 && z == 0){ + // printf("subcrs:%d, i:%d, j:%d, %f * %f = %f, acc = %f\n", subcrs, wSubRowIdx * TM + resIdxM, wSubColIdx * TN + resIdxN, + // input_frag[subcrs % 2][wSubRowIdx * TM + resIdxM], + // weight_frag[subcrs % 2][wSubColIdx * TN + resIdxN], + // input_frag[subcrs % 2][wSubRowIdx * TM + resIdxM] * + // weight_frag[subcrs % 2][wSubColIdx * TN + resIdxN], + // output_frag[(wSubRowIdx * TM + resIdxM) * (WNITER * TN) + + // (wSubColIdx * TN) + resIdxN]); + // } + } + } + } + } + } + // ldg +#pragma unroll + for (uint offset = 0; offset + rowStrideA <= BN; offset += rowStrideA) { + if(vec_load){ + if (by * BN + innerRowA + offset < param.k && innerColA * 4 + crs + BK < end_k){ + if constexpr (std::is_same_v(T, float)){ + float4 tmp = reinterpret_cast(&kernel[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK])[0]; + smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 0] = tmp.x; + smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + (BN+PAD)] = tmp.y; + smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 2*(BN+PAD)] = tmp.z; + smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 3*(BN+PAD)] = tmp.w; + } else { + float2 tmp = reinterpret_cast(&kernel[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK])[0]; + half *val = reinterpret_cast(&tmp); + smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 0] = val[0]; + smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + (BN+PAD)] = val[1]; + smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 2*(BN+PAD)] = val[2]; + smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 3*(BN+PAD)] = val[3]; + } + } else { + #pragma unroll + for (int i = 0; i < 4; ++i) + smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f; + } + }else{ + #pragma unroll + for (int i = 0; i < 4; ++i){ + if (by * BN + innerRowA + offset < param.k && innerColA * 4 + crs + BK + i < end_k){ + // float4 tmp = reinterpret_cast(¶m.weight[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK + i])[0]; + smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + i*(BN+PAD)] = kernel[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK + i]; + } else { + smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f; + } + } + } + } +#pragma unroll + for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) { + int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQ : z; + const unsigned int npq_res = (bx * BM + innerRowA + offset) % PQ; + const int posh_ori = fastdiv((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.u - param.p; + const int posw_ori = fastmodulo((ksplit > 0) ? 
npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q; + int inOffset = n * param.c * param.h * param.w ; + if(vec_load){ + const uint curR = fastdiv(innerColA * 4 + crs + BK, param.SC_fastdiv); // channel offset + const uint curS = fastdiv(fastmodulo(innerColA * 4 + crs + BK, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const uint curC = fastmodulo(fastmodulo(innerColA * 4 + crs + BK, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + + const int curH = posh_ori + curR; // input h + const int curW = posw_ori + curS; // input w + if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK < end_k){ + int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; + float4 tmp = reinterpret_cast(&input[inOffset + inOffsetTmp])[0]; + smeminput[write_flag * BM * BK + input_sts_addr + offset + 0] = tmp.x; + smeminput[write_flag * BM * BK + input_sts_addr + offset + BM] = tmp.y; + smeminput[write_flag * BM * BK + input_sts_addr + offset + 2*BM] = tmp.z; + smeminput[write_flag * BM * BK + input_sts_addr + offset + 3*BM] = tmp.w; + } else { + #pragma unroll + for (int i = 0; i < 4; ++i) + smeminput[write_flag * BM * BK + input_sts_addr + offset + i*BM] = 0.f; + } + } else { + #pragma unroll + for (int i = 0; i < 4; ++i){ + const uint curR = fastdiv(innerColA * 4 + crs + BK + i, param.SC_fastdiv); // channel offset + const uint curS = fastdiv(fastmodulo(innerColA * 4 + crs + BK + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const uint curC = fastmodulo(fastmodulo(innerColA * 4 + crs + BK + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + + const int curH = posh_ori + curR; // input h + const int curW = posw_ori + curS; // input w + if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK + i < end_k){ + int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; + smeminput[write_flag * BM * BK + input_sts_addr + offset + i*BM] = input[inOffset + inOffsetTmp]; + } else { + smeminput[write_flag * BM * BK + input_sts_addr + offset + i*BM] = 0.f; + } } } } // sts - for (int i = 0; i < 4; ++i){ - smemweight[write_flag * 132 * 8 + weight_sts_addr + i] = weight_ldg_reg[i]; - } - for (int i = 0; i < 4; ++i){ - smeminput[write_flag * 128 * 8 + input_sts_addr + i * 32] = input_ldg_reg[i]; - } + // for (int i = 0; i < 4; ++i) + // { + // smemweight[write_flag * (BN+4) * 8 + weight_sts_addr + i * (BN+4)] = weight_ldg_reg[i]; + // } + // for (int i = 0; i < 4; ++i) + // { + // smeminput[write_flag * BM * 8 + input_sts_addr + i * BM] = input_ldg_reg[i]; + // } __syncthreads(); write_flag ^= 1; #pragma unroll - for (int i = 0; i < 4; ++i){ - weight_frag[0][i] = smemweight[(load_flag ^ 1) * 132 * 8 + weight_lds_addr + i]; - weight_frag[0][i + 4] = smemweight[(load_flag ^ 1) * 132 * 8 + weight_lds_addr + i + 16]; - } + for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) #pragma unroll - for (int i = 0; i < 4; ++i){ - input_frag[0][i] = smeminput[(load_flag ^ 1) * 128 * 8 + input_lds_addr + i]; - input_frag[0][i + 4] = smeminput[(load_flag ^ 1) * 128 * 8 + input_lds_addr + i + 32]; - } + for (uint i = 0; i < TM; ++i) + input_frag[0][wSubRowIdx * TM + i] = smeminput[(load_flag ^ 1) * BM * BK + + input_lds_addr + wSubRowIdx * WSUBM + threadRowInWarp * TM + i]; #pragma unroll - for (int i = 0; i < 8; ++i){ + for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) #pragma unroll - for (int j = 0; j < 8; ++j){ - output_frag[i][j] += 
ggml_cuda_cast(weight_frag[1][j]) * input_frag[1][i]; + for (uint i = 0; i < TN; ++i) + weight_frag[0][wSubColIdx * TN + i] = smemweight[(load_flag ^ 1) * (BN+PAD) * BK + + weight_lds_addr + wSubColIdx * WSUBN + threadColInWarp * TN + i]; +// #pragma unroll +// for (int i = 0; i < 4; ++i) +// { +// weight_frag[0][i] = smemweight[(load_flag ^ 1) * (BN+4) * 8 + weight_lds_addr + i]; +// weight_frag[0][i + 4] = smemweight[(load_flag ^ 1) * (BN+4) * 8 + weight_lds_addr + i + 16]; +// } +// #pragma unroll +// for (int i = 0; i < 4; ++i) +// { +// input_frag[0][i] = smeminput[(load_flag ^ 1) * BM * 8 + input_lds_addr + i]; +// input_frag[0][i + 4] = smeminput[(load_flag ^ 1) * BM * 8 + input_lds_addr + i + 32]; +// } +#pragma unroll + for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) { +#pragma unroll + for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) { + // calculate per-thread results +#pragma unroll + for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) { +#pragma unroll + for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) { + output_frag[(wSubRowIdx * TM + resIdxM) * (WNITER * TN) + + (wSubColIdx * TN) + resIdxN] += + input_frag[1][wSubRowIdx * TM + resIdxM] * + ggml_cuda_cast(weight_frag[1][wSubColIdx * TN + resIdxN]); + } + } } } +// #pragma unroll +// for (int i = 0; i < 8; ++i) +// { +// #pragma unroll +// for (int j = 0; j < 8; ++j) +// { +// output_frag[i][j] += weight_frag[1][i] * input_frag[1][j]; +// } +// } } + // if(tx == 59 && bx == 0 && by == 0 && z == 0){ + // for (int i = 0; i < WMITER * TM * WNITER * TN; ++i){ + // printf("%f,", output_frag[i]); + // if((i+1) % (WNITER * TN) == 0) + // printf("\n"); + // } + // printf("\n"); + // } + // if(tx == 59 && bx == 0 && by == 0 && z == 0){ + // int cnt[3] = {0}; + // float values[3] = {-1.f}; + // for (int i = 0; i < WMITER * TM * WNITER * TN; ++i){ + // for(int j = 0; j < 3; j++){ + // if (output_frag[i] == values[j]){ + // cnt[j]++; + // break; + // } else{ + // if (cnt[j] == 0){ + // values[j] = output_frag[i]; + // cnt[j]++; + // break; + // } + // } + // } + // } + // for(int j = 0; j < 3; j++){ + // if(values[j] != -1.f) + // printf("value: %f, cnt: %d \n", values[j], cnt[j]); + // } + // } + // reuse smem float *smemoutput = reinterpret_cast(smem); + // float *smembias = reinterpret_cast(smem + 16 * 1024); + + // bias ldg/sts + // if (tx < BN) + // { + // smembias[tx] = param.bias[by * BN + tx]; + // } + + // constexpr uint OUTMITER = (TM * TN * WNITER * WMITER * NUM_THREADS) / (2 * BK * (BM + BN)) / OUTNITER; + // const uint WMITER_TM_OUTMITER = WMITER * TM / OUTMITER; + // const uint WNITER_TN_OUTNITER = WNITER * TN / OUTNITER; - uint32_t output_sts_addr = warp_id * 512 + mma_tid_y * 4 * 8 * 4 + mma_tid_x * 4; - uint32_t output_lds_addr = warp_id * 512 + lane_id; - uint32_t m_idx = blockIdx.y * 128 + warp_id / 2 * 32; - uint32_t n_idx = blockIdx.x * 128 + warp_id % 2 * 64 + lane_id; +// // uint32_t bias_lds_addr = warp_id / 2 * 32; + +// #pragma unroll +// for (int i = 0; i < 2; ++i) +// { +// #pragma unroll +// for (int j = 0; j < 2; ++j) +// { +// __syncthreads(); + +// #pragma unroll +// for (int subi = 0; subi < 4; ++subi) +// { +// #pragma unroll +// for (int subj = 0; subj < 4; ++subj) +// { +// // output sts +// smemoutput[output_sts_addr + subi * 8 * 4 + subj] = output_frag[i * 4 + subi][j * 4 + subj]; +// } +// } +// __syncthreads(); + +// #pragma unroll +// for (int subk = 0; subk < 16; ++subk) +// { +// int outOffset = z * param.k * param.Oh * param.Ow + (m_idx + i * 16 + subk) * param.Oh * param.Ow 
+ n_idx + j * 32; +// if ((m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow) +// param.output[outOffset] = smemoutput[output_lds_addr + subk * 32]; +// } +// } +// } + const uint output_lds_addr = warp_id * WSUBM * WSUBN + lane_id; + // const uint m_idx = by * BN + mma_tid_y * WN + threadColInWarp * WNITER_TN_OUTNITER; + // const uint n_idx = bx * BM + mma_tid_x * WM + threadRowInWarp * WMITER_TM_OUTMITER; + // const uint output_sts_addr = warp_id * WMITER_TM_OUTMITER * WNITER_TN_OUTNITER * WARPSIZE + + // (threadRowInWarp * (WSUBN / TN) + threadColInWarp) * WMITER_TM_OUTMITER * WNITER_TN_OUTNITER; + const uint output_sts_addr = mma_tid_x * BN / WN * TM * TN * WARPSIZE + mma_tid_y * TM * TN * WARPSIZE + + threadColInWarp * TN * WSUBM + threadRowInWarp * TM; + const uint m_idx = by * BN + mma_tid_y * WN; + const uint n_idx = bx * BM + mma_tid_x * WM; #pragma unroll - for (int i = 0; i < 2; ++i){ + for (int i = 0; i < WMITER; ++i) + { #pragma unroll - for (int j = 0; j < 2; ++j){ + for (int j = 0; j < WNITER; ++j) + { __syncthreads(); + #pragma unroll - for (int subi = 0; subi < 4; ++subi){ + for (int subi = 0; subi < TM; ++subi) + { #pragma unroll - for (int subj = 0; subj < 4; ++subj){ + for (int subj = 0; subj < TN; ++subj) + { // output sts - smemoutput[output_sts_addr + subj * 8 * 4 + subi] = output_frag[i * 4 + subi][j * 4 + subj]; + smemoutput[output_sts_addr + subj * WSUBM + subi] = + output_frag[(i * TM + subi) * (WNITER * TN) + j * TN + subj]; } } __syncthreads(); - #pragma unroll - for (int subk = 0; subk < 16; ++subk){ - int outOffset = z * param.k * param.Oh * param.Ow + (m_idx + j * 16 + subk) * param.Oh * param.Ow + n_idx + i * 32; - if ((m_idx + j * 16 + subk) < param.k && (n_idx + i * 32) < param.Oh * param.Ow) - output[outOffset] = smemoutput[output_lds_addr + subk * 32]; + for (int subk = 0; subk < TM * TN; ++subk){ + const uint row = m_idx + j * WSUBN + (lane_id + subk * WARPSIZE) / WSUBM; + const uint gemm_i = n_idx + i * WSUBM + (lane_id + subk * WARPSIZE) % WSUBM; + const int n = (ksplit > 0) ? gemm_i / PQ : z; + const int col = (ksplit > 0) ? gemm_i % PQ : gemm_i; + if (n < param.n && row < param.k && col < param.Oh * param.Ow){ + // int outOffset = z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow + (m_idx + i * 16 + subk) * param.Oh * param.Ow + (n_idx + j * 32); + // if (n < param.n && (m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow) + // param.interm[outOffset] = smemoutput[output_lds_addr + subk * 32]; + const uint outOffset = ksplit > 0 ? 
+ z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow + + row * param.Oh * param.Ow + col : + z * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col; + output[outOffset] = smemoutput[output_lds_addr + subk * WARPSIZE]; + } } } } } +#define NUM_VARIANTS 6 + +/* + conv_shapes[][0]: ne_input=[384,512,256,1],ne_kernel=[3,3,256,256] + conv_shapes[][1]: ne_input=[96,128,512,1],ne_kernel=[3,3,512,512] + conv_shapes[][2]: ne_input=[192,256,512,1git diff],ne_kernel=[3,3,512,512] +*/ +constexpr static int conv_shapes[][NUM_VARIANTS] = { + { 128, 128, 128 }, // BM + { 256, 128, 256 }, // BN + { 8, 8, 8 }, // BK + { 128, 64, 32 }, // WM + { 32, 32 , 256 }, // WN + { 2, 2, 1 }, // WNITER + { 8, 4, 8 }, // TM + { 8, 4, 4 }, // TN + { 256, 256, 128} // NUM_THREADS +}; + template static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const param_t P, cudaStream_t st) { int blockx = ((P.Oh * P.Ow + 127) / 128); // blockx number @@ -324,6 +745,9 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * const int64_t total = B * OC * OH * OW; param_t params = { B, IC, IH, IW, OC, KH, KW, ST_Y, ST_X, PD_Y, PD_X, DL_Y, DL_X, OH, OW }; + params.SC_fastdiv = init_fastdiv_values(KW*KH); + params.OW_fastdiv = init_fastdiv_values(OW); + params.C_fastdiv = init_fastdiv_values(IC); if (kernel->type == GGML_TYPE_F16) { conv2d_implicit_cuda_f16(X_D, (half *) K_D, Y_D, params, st); diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index 46161feb3c..4fe6134873 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -1,5 +1,27 @@ #pragma once #include "common.cuh" +typedef struct{ + unsigned int n; //batch size + unsigned int c; //number if channels + unsigned int h; //height + unsigned int w; //width + unsigned int k; //number of filters + unsigned int r; //filter height + unsigned int s; //filter width + unsigned int u; //stride height + unsigned int v; //stride width + unsigned int p; //padding height + unsigned int q; //padding width + unsigned int d_h; //dilation height + unsigned int d_w; //dilation width + unsigned int Oh; //output height + unsigned int Ow; //output width + uint3 SC_fastdiv; + uint3 OW_fastdiv; + uint3 C_fastdiv; +} param_t; + + #define CUDA_CONV2D_IMPLICT_BLOCK_SIZE 256 void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst); From 22377220569f91ae37f45cec42106ad8a969df7d Mon Sep 17 00:00:00 2001 From: bssrdf Date: Tue, 14 Oct 2025 11:02:10 -0400 Subject: [PATCH 012/122] added block variants; to be debugged --- ggml/src/ggml-cuda/conv2d-implicit.cu | 156 ++++++++++++++++++------- ggml/src/ggml-cuda/conv2d-implicit.cuh | 3 + tests/test-conv2d-implicit.cpp | 28 ++--- 3 files changed, 131 insertions(+), 56 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 0a5c370f29..0b410a460a 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -1,8 +1,11 @@ -#include "conv2d-implicit.cuh" +// #include +#include "ggml.h" +#include "common.cuh" #include "convert.cuh" - +#include "conv2d-implicit.cuh" static const int WARPSIZE = 32; // warpSize is not constexpr +typedef unsigned int uint; static __global__ void reduce_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols, const int nrows) { const int row = blockIdx.x; @@ -20,7 +23,8 @@ static __global__ void reduce_f32(const float * 
__restrict__ x, float * __restri template + // layout: 0, NHWC; 1, NCHW + const int layout, const bool vec_load, const int ksplit, const int PAD=4> static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, const T * __restrict__ kernel, float * __restrict__ output, @@ -76,7 +80,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, // int inOffset = (ksplit > 0): z * param.c * param.h * param.w ; // int weiOffset = (by * BN + tx / 8 * 4) * param.c * param.r * param.s; - int inChannelOffset = param.c * param.w; + int inChannelOffset = layout == 0 ? param.c * param.w : param.h * param.w; // int weightChannelOffset = param.r * param.s; int weightKOffset = param.c * param.r * param.s; @@ -125,16 +129,16 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, if(vec_load){ // if (by * BN + innerRowA + offset < param.k && start_k + innerColA * 4 < param.c * param.r * param.s){ if (by * BN + innerRowA + offset < param.k && start_k + innerColA * 4 < end_k){ - if constexpr (std::is_same_v(T, float)){ - float4 tmp = reinterpret_cast(&kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4])[0]; + if constexpr (std::is_same_v){ + float4 tmp = reinterpret_cast(&kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4])[0]; smemweight[weight_sts_addr + offset + 0] = tmp.x; smemweight[weight_sts_addr + offset + (BN+PAD)] = tmp.y; smemweight[weight_sts_addr + offset + 2*(BN+PAD)] = tmp.z; smemweight[weight_sts_addr + offset + 3*(BN+PAD)] = tmp.w; }else{ // read 4 halves // half val[4]; - float2 tmp = reinterpret_cast(&kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4])[0]; - half *val = reinterpret_cast(&tmp); + float2 tmp = reinterpret_cast(&kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4])[0]; + const half *val = reinterpret_cast(&tmp); // val[1] = reinterpret_cast(&tmp.y); smemweight[weight_sts_addr + offset + 0] = val[0]; smemweight[weight_sts_addr + offset + (BN+PAD)] = val[1]; @@ -177,15 +181,28 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q; int inOffset = n * param.c * param.h * param.w ; if(vec_load){ - const uint curR = fastdiv(start_k + innerColA * 4, param.SC_fastdiv); // channel offset - const uint curS = fastdiv(fastmodulo(start_k + innerColA * 4, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - const uint curC = fastmodulo(fastmodulo(start_k + innerColA * 4, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - + // const uint curR = fastdiv(start_k + innerColA * 4, param.SC_fastdiv); // channel offset + // const uint curS = fastdiv(fastmodulo(start_k + innerColA * 4, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + // const uint curC = fastmodulo(fastmodulo(start_k + innerColA * 4, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const uint cur0 = fastdiv(start_k + innerColA * 4, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset + const uint cur1 = fastdiv(fastmodulo(start_k + innerColA * 4, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + const uint cur2 = fastmodulo(fastmodulo(start_k + innerColA * 4, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + layout == 0 ? 
param.C_fastdiv : param.S_fastdiv); // kernel r offset + const uint curC = layout == 0 ? cur2 : cur0; + const uint curR = layout == 0 ? cur0 : cur1; + const uint curS = layout == 0 ? cur1 : cur2; const int curH = posh_ori + curR; // input h const int curW = posw_ori + curS; // input w if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 < end_k){ - int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; - float4 tmp = reinterpret_cast(&input[inOffset + inOffsetTmp])[0]; + // int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; + int inOffsetTmp = layout == 0 ? + curH * inChannelOffset + curW * param.c + curC: + curC * inChannelOffset + curH * param.w + curW; + float4 tmp = reinterpret_cast(&input[inOffset + inOffsetTmp])[0]; smeminput[input_sts_addr + offset + 0] = tmp.x; smeminput[input_sts_addr + offset + BM] = tmp.y; smeminput[input_sts_addr + offset + 2*BM] = tmp.z; @@ -198,14 +215,27 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, } else { #pragma unroll for (int i = 0; i < 4; ++i){ - const uint curR = fastdiv(start_k + innerColA * 4 + i, param.SC_fastdiv); // channel offset - const uint curS = fastdiv(fastmodulo(start_k + innerColA * 4 + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - const uint curC = fastmodulo(fastmodulo(start_k + innerColA * 4 + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - + // const uint curR = fastdiv(start_k + innerColA * 4 + i, param.SC_fastdiv); // channel offset + // const uint curS = fastdiv(fastmodulo(start_k + innerColA * 4 + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + // const uint curC = fastmodulo(fastmodulo(start_k + innerColA * 4 + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const uint cur0 = fastdiv(start_k + innerColA * 4 + i, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset + const uint cur1 = fastdiv(fastmodulo(start_k + innerColA * 4 + i, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + const uint cur2 = fastmodulo(fastmodulo(start_k + innerColA * 4 + i, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + const uint curC = layout == 0 ? cur2 : cur0; + const uint curR = layout == 0 ? cur0 : cur1; + const uint curS = layout == 0 ? cur1 : cur2; const int curH = posh_ori + curR; // input h const int curW = posw_ori + curS; // input w if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 + i < end_k){ - int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; + // int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; + int inOffsetTmp = layout == 0 ? 
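// Plain-integer version of the decomposition above, for illustration only (the
// kernel uses the precomputed fastdiv/fastmodulo constants instead of '/' and '%').
// A flattened reduction index k in [0, C*R*S) is split into a channel and a kernel
// tap, with the ordering depending on which dimension is contiguous in memory.
// Hypothetical helper names:
static __device__ __forceinline__ void decompose_crs(
        unsigned k, unsigned C, unsigned R, unsigned S, int layout,
        unsigned &c, unsigned &r, unsigned &s) {
    if (layout == 0) {          // NHWC: channels fastest, k = (r*S + s)*C + c
        r = k / (S * C);
        s = (k % (S * C)) / C;
        c = (k % (S * C)) % C;
    } else {                    // NCHW: kernel width fastest, k = (c*R + r)*S + s
        c = k / (R * S);
        r = (k % (R * S)) / S;
        s = (k % (R * S)) % S;
    }
}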
+ curH * inChannelOffset + curW * param.c + curC: + curC * inChannelOffset + curH * param.w + curW; smeminput[input_sts_addr + offset + i*BM] = input[inOffset + inOffsetTmp]; } else { smeminput[input_sts_addr + offset + i*BM] = 0.f; @@ -398,15 +428,15 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, for (uint offset = 0; offset + rowStrideA <= BN; offset += rowStrideA) { if(vec_load){ if (by * BN + innerRowA + offset < param.k && innerColA * 4 + crs + BK < end_k){ - if constexpr (std::is_same_v(T, float)){ - float4 tmp = reinterpret_cast(&kernel[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK])[0]; + if constexpr (std::is_same_v){ + float4 tmp = reinterpret_cast(&kernel[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK])[0]; smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 0] = tmp.x; smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + (BN+PAD)] = tmp.y; smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 2*(BN+PAD)] = tmp.z; smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 3*(BN+PAD)] = tmp.w; } else { - float2 tmp = reinterpret_cast(&kernel[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK])[0]; - half *val = reinterpret_cast(&tmp); + float2 tmp = reinterpret_cast(&kernel[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK])[0]; + const half *val = reinterpret_cast(&tmp); smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 0] = val[0]; smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + (BN+PAD)] = val[1]; smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 2*(BN+PAD)] = val[2]; @@ -437,15 +467,29 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q; int inOffset = n * param.c * param.h * param.w ; if(vec_load){ - const uint curR = fastdiv(innerColA * 4 + crs + BK, param.SC_fastdiv); // channel offset - const uint curS = fastdiv(fastmodulo(innerColA * 4 + crs + BK, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - const uint curC = fastmodulo(fastmodulo(innerColA * 4 + crs + BK, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + // const uint curR = fastdiv(innerColA * 4 + crs + BK, param.SC_fastdiv); // channel offset + // const uint curS = fastdiv(fastmodulo(innerColA * 4 + crs + BK, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + // const uint curC = fastmodulo(fastmodulo(innerColA * 4 + crs + BK, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const uint cur0 = fastdiv(innerColA * 4 + crs + BK, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset + const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + const uint curC = layout == 0 ? cur2 : cur0; + const uint curR = layout == 0 ? cur0 : cur1; + const uint curS = layout == 0 ? 
cur1 : cur2; const int curH = posh_ori + curR; // input h const int curW = posw_ori + curS; // input w if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK < end_k){ - int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; - float4 tmp = reinterpret_cast(&input[inOffset + inOffsetTmp])[0]; + // int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; + int inOffsetTmp = layout == 0 ? + curH * inChannelOffset + curW * param.c + curC: + curC * inChannelOffset + curH * param.w + curW; + float4 tmp = reinterpret_cast(&input[inOffset + inOffsetTmp])[0]; smeminput[write_flag * BM * BK + input_sts_addr + offset + 0] = tmp.x; smeminput[write_flag * BM * BK + input_sts_addr + offset + BM] = tmp.y; smeminput[write_flag * BM * BK + input_sts_addr + offset + 2*BM] = tmp.z; @@ -458,14 +502,28 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, } else { #pragma unroll for (int i = 0; i < 4; ++i){ - const uint curR = fastdiv(innerColA * 4 + crs + BK + i, param.SC_fastdiv); // channel offset - const uint curS = fastdiv(fastmodulo(innerColA * 4 + crs + BK + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - const uint curC = fastmodulo(fastmodulo(innerColA * 4 + crs + BK + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + // const uint curR = fastdiv(innerColA * 4 + crs + BK + i, param.SC_fastdiv); // channel offset + // const uint curS = fastdiv(fastmodulo(innerColA * 4 + crs + BK + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + // const uint curC = fastmodulo(fastmodulo(innerColA * 4 + crs + BK + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const uint cur0 = fastdiv(innerColA * 4 + crs + BK + i, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset + const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK + i, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK + i, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + const uint curC = layout == 0 ? cur2 : cur0; + const uint curR = layout == 0 ? cur0 : cur1; + const uint curS = layout == 0 ? cur1 : cur2; const int curH = posh_ori + curR; // input h const int curW = posw_ori + curS; // input w if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK + i < end_k){ - int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; + // int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; + int inOffsetTmp = layout == 0 ? 
+ curH * inChannelOffset + curW * param.c + curC: + curC * inChannelOffset + curH * param.w + curW; smeminput[write_flag * BM * BK + input_sts_addr + offset + i*BM] = input[inOffset + inOffsetTmp]; } else { smeminput[write_flag * BM * BK + input_sts_addr + offset + i*BM] = 0.f; @@ -684,26 +742,37 @@ constexpr static int conv_shapes[][NUM_VARIANTS] = { { 256, 256, 128} // NUM_THREADS }; -template +template static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const param_t P, cudaStream_t st) { - int blockx = ((P.Oh * P.Ow + 127) / 128); // blockx number - int blocky = (P.k + 127) / 128; // blocky number + + const uint BM = conv_shapes[0][CONV_SHAPE]; + const uint BN = conv_shapes[1][CONV_SHAPE]; + const uint BK = conv_shapes[2][CONV_SHAPE]; + const uint WM = conv_shapes[3][CONV_SHAPE]; + const uint WN = conv_shapes[4][CONV_SHAPE]; + const uint WNITER = conv_shapes[5][CONV_SHAPE]; + const uint TM = conv_shapes[6][CONV_SHAPE]; + const uint TN = conv_shapes[7][CONV_SHAPE]; + const uint NUM_THREADS = conv_shapes[8][CONV_SHAPE]; + int blockx = ((P.Oh * P.Ow + BM - 1) / BM); // blockx number + int blocky = (P.k + BN-1) / BN; // blocky number int blockz = P.n; // blockz number - int threadx = CUDA_CONV2D_IMPLICT_BLOCK_SIZE; // threadx number per block + // int threadx = NUM; // threadx number per block int thready = 1; // thready number per block int threadz = 1; // threadz number per block - dim3 thblock(threadx, thready, threadz); + dim3 thblock(NUM_THREADS, thready, threadz); dim3 grid(blockx, blocky, blockz); - int smem_size = 24 * 1024; - conv2d_implicit_kernel<<>>(X_D, K_D, Y_D, P); + // int smem_size = 24 * 1024; + conv2d_implicit_kernel<<>>(X_D, K_D, Y_D, P); } static void conv2d_implicit_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const param_t P, cudaStream_t st) { - conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); + conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); } static void conv2d_implicit_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const param_t P, cudaStream_t st) { - conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); + conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); } void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -745,9 +814,12 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * const int64_t total = B * OC * OH * OW; param_t params = { B, IC, IH, IW, OC, KH, KW, ST_Y, ST_X, PD_Y, PD_X, DL_Y, DL_X, OH, OW }; - params.SC_fastdiv = init_fastdiv_values(KW*KH); + params.SC_fastdiv = init_fastdiv_values(KW*IC); params.OW_fastdiv = init_fastdiv_values(OW); params.C_fastdiv = init_fastdiv_values(IC); + params.RS_fastdiv = init_fastdiv_values(KW*KH); + params.S_fastdiv = init_fastdiv_values(KW); + params.nchw = false; if (kernel->type == GGML_TYPE_F16) { conv2d_implicit_cuda_f16(X_D, (half *) K_D, Y_D, params, st); diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index 4fe6134873..d2f3cffcc3 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -17,9 +17,12 @@ typedef struct{ unsigned int d_w; //dilation width unsigned int Oh; //output height unsigned int Ow; //output width + bool nchw; uint3 SC_fastdiv; uint3 OW_fastdiv; uint3 C_fastdiv; + uint3 RS_fastdiv; + uint3 S_fastdiv; } param_t; diff --git a/tests/test-conv2d-implicit.cpp b/tests/test-conv2d-implicit.cpp index e963e2b361..0ac438a137 100644 --- a/tests/test-conv2d-implicit.cpp +++ b/tests/test-conv2d-implicit.cpp @@ -339,21 
+339,21 @@ int main(void) { ggml_time_init(); std::vector> configs = { - std::make_tuple(64,64,48,64), - std::make_tuple(320,320,104,152), - std::make_tuple(640,640,52,76), - std::make_tuple(640,640,104,152), - std::make_tuple(960,320,104,152), - std::make_tuple(1280,1280,26,38), - std::make_tuple(1280,640,52,76), - std::make_tuple(1920,1280,26,38), - std::make_tuple(2560,1280,26,38), - std::make_tuple(512,512,104,152), - std::make_tuple(512,512,208,304), + // std::make_tuple(64,64,48,64), + // std::make_tuple(320,320,104,152), + // std::make_tuple(640,640,52,76), + // std::make_tuple(640,640,104,152), + // std::make_tuple(960,320,104,152), + // std::make_tuple(1280,1280,26,38), + // std::make_tuple(1280,640,52,76), + // std::make_tuple(1920,1280,26,38), + // std::make_tuple(2560,1280,26,38), + // std::make_tuple(512,512,104,152), + // std::make_tuple(512,512,208,304), std::make_tuple(512,256,416,608), - std::make_tuple(256,128,832,1216), - std::make_tuple(256,256,832,1216), - std::make_tuple(320,256,1024,1920) + // std::make_tuple(256,128,832,1216), + // std::make_tuple(256,256,832,1216), + // std::make_tuple(320,256,1024,1920) }; int k = 0; From 3e2f722d11e39d9f121829e27252222bed569b46 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Tue, 14 Oct 2025 11:12:55 -0400 Subject: [PATCH 013/122] fixed missing dilation --- ggml/src/ggml-cuda/conv2d-implicit.cu | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 0b410a460a..8cd17aff84 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -195,8 +195,8 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, const uint curC = layout == 0 ? cur2 : cur0; const uint curR = layout == 0 ? cur0 : cur1; const uint curS = layout == 0 ? cur1 : cur2; - const int curH = posh_ori + curR; // input h - const int curW = posw_ori + curS; // input w + const int curH = posh_ori + curR * param.d_h; // input h + const int curW = posw_ori + curS * param.d_w; // input w if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 < end_k){ // int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; int inOffsetTmp = layout == 0 ? @@ -229,8 +229,8 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, const uint curC = layout == 0 ? cur2 : cur0; const uint curR = layout == 0 ? cur0 : cur1; const uint curS = layout == 0 ? cur1 : cur2; - const int curH = posh_ori + curR; // input h - const int curW = posw_ori + curS; // input w + const int curH = posh_ori + curR * param.d_h; // input h + const int curW = posw_ori + curS * param.d_w; // input w if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 + i < end_k){ // int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; int inOffsetTmp = layout == 0 ? @@ -482,8 +482,8 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, const uint curR = layout == 0 ? cur0 : cur1; const uint curS = layout == 0 ? 
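// For reference, the mapping these hunks restore: an output pixel together with a
// kernel tap addresses the input at out*stride - pad + tap*dilation. Minimal
// illustrative sketch with hypothetical names; coordinates falling outside the
// input are the zero-padding region and are skipped by the bounds checks around
// the loads.
static __device__ __forceinline__ int conv_input_coord(int out, int tap,
                                                        int stride, int pad, int dilation) {
    return out * stride - pad + tap * dilation;   // e.g. curH = oh*u - p + r*d_h
}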
cur1 : cur2; - const int curH = posh_ori + curR; // input h - const int curW = posw_ori + curS; // input w + const int curH = posh_ori + curR * param.d_h; // input h + const int curW = posw_ori + curS * param.d_w; // input w if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK < end_k){ // int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; int inOffsetTmp = layout == 0 ? @@ -517,8 +517,8 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, const uint curR = layout == 0 ? cur0 : cur1; const uint curS = layout == 0 ? cur1 : cur2; - const int curH = posh_ori + curR; // input h - const int curW = posw_ori + curS; // input w + const int curH = posh_ori + curR * param.d_h; // input h + const int curW = posw_ori + curS * param.d_w; // input w if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK + i < end_k){ // int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; int inOffsetTmp = layout == 0 ? From b70cca2ea371f5c301f3f6220348d3f80beb94ed Mon Sep 17 00:00:00 2001 From: bssrdf Date: Tue, 14 Oct 2025 14:24:35 -0400 Subject: [PATCH 014/122] add support for both NCHW and NHWC layouts --- ggml/include/ggml.h | 3 +- ggml/src/ggml-cuda/conv2d-implicit.cu | 68 +++++++++++++++++--------- ggml/src/ggml-cuda/conv2d-implicit.cuh | 2 +- ggml/src/ggml.c | 18 +++++-- tests/test-backend-ops.cpp | 41 +++++++++++++++- tests/test-conv2d-implicit.cpp | 30 ++++++------ 6 files changed, 118 insertions(+), 44 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index b558b29874..3999acbd4e 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1992,7 +1992,8 @@ extern "C" { int p0, // padding dimension 0 int p1, // padding dimension 1 int d0, // dilation dimension 0 - int d1); // dilation dimension 1 + int d1, + int layout); // dilation dimension 1 GGML_API struct ggml_tensor * ggml_conv_3d_direct( diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 8cd17aff84..a1693dcf24 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -731,15 +731,15 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, conv_shapes[][2]: ne_input=[192,256,512,1git diff],ne_kernel=[3,3,512,512] */ constexpr static int conv_shapes[][NUM_VARIANTS] = { - { 128, 128, 128 }, // BM - { 256, 128, 256 }, // BN - { 8, 8, 8 }, // BK - { 128, 64, 32 }, // WM - { 32, 32 , 256 }, // WN - { 2, 2, 1 }, // WNITER - { 8, 4, 8 }, // TM - { 8, 4, 4 }, // TN - { 256, 256, 128} // NUM_THREADS + { 128, 128, 128, 256 }, // BM + { 256, 128, 256, 128 }, // BN + { 8, 8, 8, 8 }, // BK + { 128, 64, 32, 128 }, // WM + { 32, 32 , 256, 32 }, // WN + { 2, 2, 1, 1 }, // WNITER + { 8, 4, 4, 4 }, // TM + { 8, 4, 8, 8 }, // TN + { 256, 256, 128, 256} // NUM_THREADS }; template @@ -763,16 +763,29 @@ static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, dim3 thblock(NUM_THREADS, thready, threadz); dim3 grid(blockx, blocky, blockz); // int smem_size = 24 * 1024; - conv2d_implicit_kernel<<>>(X_D, K_D, Y_D, P); + if(P.c % 4 == 0){ + if(P.layout == 0) + conv2d_implicit_kernel<<>>(X_D, K_D, Y_D, P); + else if(P.layout == 1) + conv2d_implicit_kernel<<>>(X_D, K_D, Y_D, P); + } else{ + if(P.layout == 0) + conv2d_implicit_kernel<<>>(X_D, K_D, Y_D, P); + else if(P.layout == 1) + conv2d_implicit_kernel<<>>(X_D, K_D, Y_D, P); + } } static void conv2d_implicit_cuda_f16(const float * X_D, const half * 
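// A note on the P.c % 4 dispatch above, as an illustrative sketch: the vec_load
// variants read the input with 16-byte float4 accesses, and for the NHWC path the
// four lanes of one float4 are four consecutive channels of the same pixel, so the
// fast path is only selected when the channel count keeps those addresses aligned.
// Hypothetical helper showing the guarded load:
static __device__ __forceinline__ float4 load4_input_channels(const float * base) {
    // caller guarantees the index feeding 'base' is a multiple of 4 (C % 4 == 0),
    // which keeps the address 16-byte aligned for the vector load
    return *reinterpret_cast<const float4 *>(base);
}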
K_D, float * Y_D, const param_t P, cudaStream_t st) { - conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); + conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); } static void conv2d_implicit_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const param_t P, cudaStream_t st) { - conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); + conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); } void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -785,8 +798,6 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * GGML_ASSERT(ggml_is_contiguous(kernel)); GGML_ASSERT(kernel->type == GGML_TYPE_F16 || kernel->type == GGML_TYPE_F32); - // same number of input channels - GGML_ASSERT(input->ne[2] == kernel->ne[2]); cudaStream_t st = ctx.stream(); @@ -797,17 +808,30 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * const int PD_Y = p[3]; // padding_y const int DL_X = p[4]; // dilation_x const int DL_Y = p[5]; // dilation_y + const int LT = p[6]; // layout + GGML_ASSERT(LT == 0 || LT == 1); + + // same number of input channels + GGML_ASSERT(LT == 0 ? input->ne[0] == kernel->ne[0] : input->ne[2] == kernel->ne[2]); // No cwhn - GGML_ASSERT(p[6] == false); + GGML_ASSERT(p[7] == false); - const int IW = input->ne[0]; // input_w - const int IH = input->ne[1]; // input_h + // const int IW = input->ne[0]; // input_w + // const int IH = input->ne[1]; // input_h + // const int OW = dst->ne[0]; // output_w + // const int OH = dst->ne[1]; // output_h + // const int KW = kernel->ne[0]; // kernel_w + // const int KH = kernel->ne[1]; // kernel_h + // const int IC = input->ne[2]; // input_channels + const int IW = input->ne[LT == 0 ? 1 : 0]; // input_w + const int IH = input->ne[LT == 0 ? 2 : 1]; // input_h const int OW = dst->ne[0]; // output_w const int OH = dst->ne[1]; // output_h - const int KW = kernel->ne[0]; // kernel_w - const int KH = kernel->ne[1]; // kernel_h - const int IC = input->ne[2]; // input_channels + const int KW = kernel->ne[LT == 0 ? 1 : 0]; // kernel_w + const int KH = kernel->ne[LT == 0 ? 2 : 1]; // kernel_h + const int IC = input->ne[LT == 0 ? 
0: 2]; // input_channels + const int OC = kernel->ne[3]; // ouptut_chanles const int B = input->ne[3]; // n_batches @@ -819,7 +843,7 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * params.C_fastdiv = init_fastdiv_values(IC); params.RS_fastdiv = init_fastdiv_values(KW*KH); params.S_fastdiv = init_fastdiv_values(KW); - params.nchw = false; + params.layout = LT; if (kernel->type == GGML_TYPE_F16) { conv2d_implicit_cuda_f16(X_D, (half *) K_D, Y_D, params, st); diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index d2f3cffcc3..e46d93ef4f 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -17,7 +17,7 @@ typedef struct{ unsigned int d_w; //dilation width unsigned int Oh; //output height unsigned int Ow; //output width - bool nchw; + unsigned int layout; uint3 SC_fastdiv; uint3 OW_fastdiv; uint3 C_fastdiv; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 1c746c687a..7fa97e84de 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -4584,7 +4584,9 @@ struct ggml_tensor * ggml_conv_2d_implicitgemm( int p0, // padding dimension 0 int p1, // padding dimension 1 int d0, // dilation dimension 0 - int d1) {// dilation dimension 1 + int d1, + // 0: NHWC, 1:NCHW + int layout) {// dilation dimension 1 GGML_ASSERT(a->ne[2] == b->ne[2]); //GGML_ASSERT(a->type == b->type); @@ -4603,10 +4605,20 @@ struct ggml_tensor * ggml_conv_2d_implicitgemm( ggml_set_op_params_i32(result, 3, p1); ggml_set_op_params_i32(result, 4, d0); ggml_set_op_params_i32(result, 5, d1); + ggml_set_op_params_i32(result, 6, layout); + + struct ggml_tensor *ap, *bp; + if(layout == 0){ + ap = ggml_cont(ctx, ggml_permute(ctx, a, 1, 2, 0, 3)); + bp = ggml_cont(ctx, ggml_permute(ctx, b, 1, 2, 0, 3)); + } else{ + ap = a; + bp = b; + } result->op = GGML_OP_CONV_2D_IMPLICIT; - result->src[0] = a; - result->src[1] = b; + result->src[0] = ap; + result->src[1] = bp; return result; } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 3c4388f8a5..49a5688acd 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4268,7 +4268,7 @@ struct test_conv_2d_implicit : public test_case { } ggml_tensor * out = - ggml_conv_2d_implicitgemm(ctx, kernel, input, stride0, stride1, padding0, padding1, dilation0, dilation1); + ggml_conv_2d_implicitgemm(ctx, kernel, input, stride0, stride1, padding0, padding1, dilation0, dilation1, cwhn?0:1); ggml_set_name(out, "out"); return out; } @@ -6788,6 +6788,45 @@ static std::vector> make_test_cases_perf() { GGML_TYPE_F16, 1, 1, p0, p1, 1, 1, false)); } + for (auto act_case : cases_sd) { + GGML_ASSERT(act_case[idx_sd["kw"]] == 3 || act_case[idx_sd["kw"]] == 1); + GGML_ASSERT(act_case[idx_sd["kh"]] == 3 || act_case[idx_sd["kh"]] == 1); + + uint32_t p0 = act_case[idx_sd["kw"]] == 3 ? 1 : 0; + uint32_t p1 = act_case[idx_sd["kh"]] == 3 ? 1 : 0; + + test_cases.emplace_back(new test_conv_2d_implicit( + { act_case[idx_sd["iw"]], act_case[idx_sd["ih"]], act_case[idx_sd["Cin"]], act_case[idx_sd["B"]] }, + { act_case[idx_sd["kw"]], act_case[idx_sd["kh"]], act_case[idx_sd["Cin"]], act_case[idx_sd["Cout"]] }, + GGML_TYPE_F32, 1, 1, p0, p1, 1, 1, false)); + } + + for (auto act_case : cases_sd) { + GGML_ASSERT(act_case[idx_sd["kw"]] == 3 || act_case[idx_sd["kw"]] == 1); + GGML_ASSERT(act_case[idx_sd["kh"]] == 3 || act_case[idx_sd["kh"]] == 1); + + uint32_t p0 = act_case[idx_sd["kw"]] == 3 ? 1 : 0; + uint32_t p1 = act_case[idx_sd["kh"]] == 3 ? 
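// Background for the permute in ggml_conv_2d_implicitgemm above, as a sketch: ggml
// tensors store ne[0] fastest, so the default activation is [W, H, C, N] (NCHW in
// the usual naming), and ggml_permute(..., 1, 2, 0, 3) followed by ggml_cont
// materialises a [C, W, H, N] tensor, i.e. channels-fastest NHWC, which is what the
// layout == 0 kernel indexes. Hypothetical host-side equivalent of that copy:
static void whcn_to_cwhn(const float * src, float * dst, int W, int H, int C, int N) {
    for (int n = 0; n < N; ++n)
        for (int c = 0; c < C; ++c)
            for (int h = 0; h < H; ++h)
                for (int w = 0; w < W; ++w)
                    dst[(((long long) n * H + h) * W + w) * C + c] =   // NHWC index
                    src[(((long long) n * C + c) * H + h) * W + w];    // NCHW index
}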
1 : 0; + + test_cases.emplace_back(new test_conv_2d_implicit( + { act_case[idx_sd["iw"]], act_case[idx_sd["ih"]], act_case[idx_sd["Cin"]], act_case[idx_sd["B"]] }, + { act_case[idx_sd["kw"]], act_case[idx_sd["kh"]], act_case[idx_sd["Cin"]], act_case[idx_sd["Cout"]] }, + GGML_TYPE_F16, 1, 1, p0, p1, 1, 1, true)); + } + + for (auto act_case : cases_sd) { + GGML_ASSERT(act_case[idx_sd["kw"]] == 3 || act_case[idx_sd["kw"]] == 1); + GGML_ASSERT(act_case[idx_sd["kh"]] == 3 || act_case[idx_sd["kh"]] == 1); + + uint32_t p0 = act_case[idx_sd["kw"]] == 3 ? 1 : 0; + uint32_t p1 = act_case[idx_sd["kh"]] == 3 ? 1 : 0; + + test_cases.emplace_back(new test_conv_2d_implicit( + { act_case[idx_sd["iw"]], act_case[idx_sd["ih"]], act_case[idx_sd["Cin"]], act_case[idx_sd["B"]] }, + { act_case[idx_sd["kw"]], act_case[idx_sd["kh"]], act_case[idx_sd["Cin"]], act_case[idx_sd["Cout"]] }, + GGML_TYPE_F32, 1, 1, p0, p1, 1, 1, true)); + } + test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1})); test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1})); diff --git a/tests/test-conv2d-implicit.cpp b/tests/test-conv2d-implicit.cpp index 0ac438a137..0b8de368d1 100644 --- a/tests/test-conv2d-implicit.cpp +++ b/tests/test-conv2d-implicit.cpp @@ -262,7 +262,7 @@ struct ggml_cgraph * build_graph_2(const test_model& model) { // printf("conv2d: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); - struct ggml_tensor* wino_res = ggml_conv_2d_implicitgemm(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); + struct ggml_tensor* wino_res = ggml_conv_2d_implicitgemm(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1, 0); // struct ggml_tensor* wino_res = ggml_conv_2d_direct(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); ggml_set_name(wino_res, "wino_res"); ggml_build_forward_expand(gf, wino_res); @@ -339,20 +339,20 @@ int main(void) { ggml_time_init(); std::vector> configs = { - // std::make_tuple(64,64,48,64), - // std::make_tuple(320,320,104,152), - // std::make_tuple(640,640,52,76), - // std::make_tuple(640,640,104,152), - // std::make_tuple(960,320,104,152), - // std::make_tuple(1280,1280,26,38), - // std::make_tuple(1280,640,52,76), - // std::make_tuple(1920,1280,26,38), - // std::make_tuple(2560,1280,26,38), - // std::make_tuple(512,512,104,152), - // std::make_tuple(512,512,208,304), + std::make_tuple(64,64,48,64), + std::make_tuple(320,320,104,152), + std::make_tuple(640,640,52,76), + std::make_tuple(640,640,104,152), + std::make_tuple(960,320,104,152), + std::make_tuple(1280,1280,26,38), + std::make_tuple(1280,640,52,76), + std::make_tuple(1920,1280,26,38), + std::make_tuple(2560,1280,26,38), + std::make_tuple(512,512,104,152), + std::make_tuple(512,512,208,304), std::make_tuple(512,256,416,608), - // std::make_tuple(256,128,832,1216), - // std::make_tuple(256,256,832,1216), + std::make_tuple(256,128,832,1216), + std::make_tuple(256,256,832,1216), // std::make_tuple(320,256,1024,1920) }; @@ -453,8 +453,6 @@ int main(void) } - // printf("\nPerforming test:\n"); - return 0; } From 3f99818925be4a527c1961b2a441a2fb3e5a2213 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 15 Oct 2025 12:46:46 -0400 Subject: [PATCH 015/122] unroll some loops --- ggml/src/ggml-cuda/conv2d-implicit.cu | 17 ++++++++++++----- tests/test-conv2d-implicit.cpp | 2 +- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index a1693dcf24..cf917a4148 100644 --- 
a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -4,8 +4,9 @@ #include "convert.cuh" #include "conv2d-implicit.cuh" -static const int WARPSIZE = 32; // warpSize is not constexpr + typedef unsigned int uint; +constexpr uint WARPSIZE = 32; static __global__ void reduce_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols, const int nrows) { const int row = blockIdx.x; @@ -125,6 +126,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, // ldg const uint weight_sts_addr = innerRowA + innerColA * (BN+PAD) * 4; +#pragma unroll for (uint offset = 0; offset + rowStrideA <= BN; offset += rowStrideA) { if(vec_load){ // if (by * BN + innerRowA + offset < param.k && start_k + innerColA * 4 < param.c * param.r * param.s){ @@ -174,6 +176,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, // int curC = ((tx % 2) * 4 % (param.s * param.c)) % param.c; // kernel s offset const uint input_sts_addr = innerRowA + innerColA * BM * 4; +#pragma unroll for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) { int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQ : z; const unsigned int npq_res = (bx * BM + innerRowA + offset) % PQ; @@ -278,14 +281,18 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, // lds // int input_lds_addr = (warp_id % 2) * 64 + mma_tid_x * 4; const uint input_lds_addr = mma_tid_x * WM; +#pragma unroll for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) +#pragma unroll for (uint i = 0; i < TM; ++i) input_frag[0][wSubRowIdx * TM + i] = smeminput[input_lds_addr + wSubRowIdx * WSUBM + threadRowInWarp * TM + i]; // int weight_lds_addr = (warp_id / 2) * 32 + mma_tid_y * 4; const uint weight_lds_addr = mma_tid_y * WN; +#pragma unroll for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) +#pragma unroll for (uint i = 0; i < TN; ++i) weight_frag[0][wSubColIdx * TN + i] = smemweight[weight_lds_addr + wSubColIdx * WSUBN + threadColInWarp * TN + i]; @@ -495,7 +502,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, smeminput[write_flag * BM * BK + input_sts_addr + offset + 2*BM] = tmp.z; smeminput[write_flag * BM * BK + input_sts_addr + offset + 3*BM] = tmp.w; } else { - #pragma unroll +#pragma unroll for (int i = 0; i < 4; ++i) smeminput[write_flag * BM * BK + input_sts_addr + offset + i*BM] = 0.f; } @@ -781,11 +788,11 @@ static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, } static void conv2d_implicit_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const param_t P, cudaStream_t st) { - conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); + conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); } static void conv2d_implicit_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const param_t P, cudaStream_t st) { - conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); + conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); } void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -811,7 +818,7 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * const int LT = p[6]; // layout GGML_ASSERT(LT == 0 || LT == 1); - + // same number of input channels GGML_ASSERT(LT == 0 ? 
input->ne[0] == kernel->ne[0] : input->ne[2] == kernel->ne[2]); // No cwhn diff --git a/tests/test-conv2d-implicit.cpp b/tests/test-conv2d-implicit.cpp index 0b8de368d1..4d416e748c 100644 --- a/tests/test-conv2d-implicit.cpp +++ b/tests/test-conv2d-implicit.cpp @@ -262,7 +262,7 @@ struct ggml_cgraph * build_graph_2(const test_model& model) { // printf("conv2d: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); - struct ggml_tensor* wino_res = ggml_conv_2d_implicitgemm(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1, 0); + struct ggml_tensor* wino_res = ggml_conv_2d_implicitgemm(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1, 1); // struct ggml_tensor* wino_res = ggml_conv_2d_direct(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); ggml_set_name(wino_res, "wino_res"); ggml_build_forward_expand(gf, wino_res); From ac77b8d0e00ed8f44369629c2c6a154a5bce192e Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 15 Oct 2025 14:07:24 -0400 Subject: [PATCH 016/122] change padding size to 1; added padding to input smem --- ggml/src/ggml-cuda/conv2d-implicit.cu | 42 +++++++++++++-------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index cf917a4148..66af15c167 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -25,15 +25,15 @@ static __global__ void reduce_f32(const float * __restrict__ x, float * __restri template + const int layout, const bool vec_load, const int ksplit, const int PAD=1> static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, const T * __restrict__ kernel, float * __restrict__ output, const param_t param) { // __shared__ char smem[4 * (TM*TN*NUM_THREADS <= (BM * BK + BK * (BN+PAD)) ? (BM * BK + BK * (BN+PAD)) : (TM*TN*NUM_THREADS))]; - __shared__ char smem[sizeof(float) * (TM*TN*NUM_THREADS) <= sizeof(float) * 2 * BM * BK + sizeof(T)*2*BK * (BN+PAD) ? - sizeof(float)*2*BM*BK + sizeof(T)*2*BK*(BN+PAD) : sizeof(float) * (TM*TN*NUM_THREADS)]; + __shared__ char smem[sizeof(float) * (TM*TN*NUM_THREADS) <= sizeof(float) * 2 * (BM+PAD) * BK + sizeof(T)*2*BK * (BN+PAD) ? + sizeof(float)*2*(BM+PAD)*BK + sizeof(T)*2*BK*(BN+PAD) : sizeof(float) * (TM*TN*NUM_THREADS)]; // __shared__ float smeminput[2 * BM * BK]; // __shared__ float smemweight[2 * BK * (BN+PAD)]; T *smemweight = reinterpret_cast(smem); @@ -175,7 +175,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, // int curS = ((tx % 2) * 4 % (param.s * param.c)) / param.c; // kernel r offset // int curC = ((tx % 2) * 4 % (param.s * param.c)) % param.c; // kernel s offset - const uint input_sts_addr = innerRowA + innerColA * BM * 4; + const uint input_sts_addr = innerRowA + innerColA * (BM+PAD) * 4; #pragma unroll for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) { int n = (ksplit > 0) ? 
(bx * BM + innerRowA + offset) / PQ : z; @@ -206,14 +206,14 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, curH * inChannelOffset + curW * param.c + curC: curC * inChannelOffset + curH * param.w + curW; float4 tmp = reinterpret_cast(&input[inOffset + inOffsetTmp])[0]; - smeminput[input_sts_addr + offset + 0] = tmp.x; - smeminput[input_sts_addr + offset + BM] = tmp.y; - smeminput[input_sts_addr + offset + 2*BM] = tmp.z; - smeminput[input_sts_addr + offset + 3*BM] = tmp.w; + smeminput[input_sts_addr + offset + 0] = tmp.x; + smeminput[input_sts_addr + offset + BM+PAD] = tmp.y; + smeminput[input_sts_addr + offset + 2*(BM+PAD)] = tmp.z; + smeminput[input_sts_addr + offset + 3*(BM+PAD)] = tmp.w; } else { #pragma unroll for (int i = 0; i < 4; ++i) - smeminput[input_sts_addr + offset + i*BM] = 0.f; + smeminput[input_sts_addr + offset + i*(BM+PAD)] = 0.f; } } else { #pragma unroll @@ -239,9 +239,9 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, int inOffsetTmp = layout == 0 ? curH * inChannelOffset + curW * param.c + curC: curC * inChannelOffset + curH * param.w + curW; - smeminput[input_sts_addr + offset + i*BM] = input[inOffset + inOffsetTmp]; + smeminput[input_sts_addr + offset + i*(BM+PAD)] = input[inOffset + inOffsetTmp]; } else { - smeminput[input_sts_addr + offset + i*BM] = 0.f; + smeminput[input_sts_addr + offset + i*(BM+PAD)] = 0.f; } } } @@ -390,8 +390,8 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) #pragma unroll for (uint i = 0; i < TM; ++i) - input_frag[(subcrs + 1) % 2][wSubRowIdx * TM + i] = smeminput[load_flag * BM * BK + - (subcrs + 1) * BM + input_lds_addr + wSubRowIdx * WSUBM + threadRowInWarp * TM + i]; + input_frag[(subcrs + 1) % 2][wSubRowIdx * TM + i] = smeminput[load_flag * (BM+PAD) * BK + + (subcrs + 1) * (BM+PAD) + input_lds_addr + wSubRowIdx * WSUBM + threadRowInWarp * TM + i]; // #pragma unroll // for (int i = 0; i < 8; ++i) @@ -497,14 +497,14 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, curH * inChannelOffset + curW * param.c + curC: curC * inChannelOffset + curH * param.w + curW; float4 tmp = reinterpret_cast(&input[inOffset + inOffsetTmp])[0]; - smeminput[write_flag * BM * BK + input_sts_addr + offset + 0] = tmp.x; - smeminput[write_flag * BM * BK + input_sts_addr + offset + BM] = tmp.y; - smeminput[write_flag * BM * BK + input_sts_addr + offset + 2*BM] = tmp.z; - smeminput[write_flag * BM * BK + input_sts_addr + offset + 3*BM] = tmp.w; + smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + 0] = tmp.x; + smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + BM+PAD] = tmp.y; + smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + 2*(BM+PAD)] = tmp.z; + smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + 3*(BM+PAD)] = tmp.w; } else { #pragma unroll for (int i = 0; i < 4; ++i) - smeminput[write_flag * BM * BK + input_sts_addr + offset + i*BM] = 0.f; + smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + i*(BM+PAD)] = 0.f; } } else { #pragma unroll @@ -531,9 +531,9 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, int inOffsetTmp = layout == 0 ? 
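// Rationale for the +PAD stride introduced here, as an illustrative note: shared
// memory is organised in 32 four-byte banks, so a float at index i lives in bank
// i % 32. With a row stride that is a multiple of 32 (e.g. BM = 128) the same
// column of successive rows always maps to the same bank; padding the stride to
// BM + PAD staggers the rows across banks. Hypothetical helper:
static __host__ __device__ unsigned smem_bank_of_float(unsigned row, unsigned col, unsigned stride) {
    return (row * stride + col) % 32u;
}
// e.g. smem_bank_of_float(r, c, 128) is identical for every r, while
// smem_bank_of_float(r, c, 129) walks through a different bank on each row.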
curH * inChannelOffset + curW * param.c + curC: curC * inChannelOffset + curH * param.w + curW; - smeminput[write_flag * BM * BK + input_sts_addr + offset + i*BM] = input[inOffset + inOffsetTmp]; + smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + i*(BM+PAD)] = input[inOffset + inOffsetTmp]; } else { - smeminput[write_flag * BM * BK + input_sts_addr + offset + i*BM] = 0.f; + smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + i*(BM+PAD)] = 0.f; } } } @@ -553,7 +553,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) #pragma unroll for (uint i = 0; i < TM; ++i) - input_frag[0][wSubRowIdx * TM + i] = smeminput[(load_flag ^ 1) * BM * BK + + input_frag[0][wSubRowIdx * TM + i] = smeminput[(load_flag ^ 1) * (BM+PAD) * BK + input_lds_addr + wSubRowIdx * WSUBM + threadRowInWarp * TM + i]; #pragma unroll for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) From 6a1f8b4d5788b0def68e2edae26f884688380115 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 15 Oct 2025 14:21:04 -0400 Subject: [PATCH 017/122] change padding size back to 4 --- ggml/src/ggml-cuda/conv2d-implicit.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 66af15c167..93ede3efc8 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -25,7 +25,7 @@ static __global__ void reduce_f32(const float * __restrict__ x, float * __restri template + const int layout, const bool vec_load, const int ksplit, const int PAD=4> static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, const T * __restrict__ kernel, float * __restrict__ output, From 15484c9bd66dd972f7990fa75589030c8fc59bd5 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Fri, 17 Oct 2025 22:16:16 -0400 Subject: [PATCH 018/122] turn on tests for implicit conv2d --- tests/test-backend-ops.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 49a5688acd..1ffa3cf6e4 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6726,15 +6726,15 @@ static std::vector> make_test_cases_perf() { } } - // for (auto kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) { - // for (auto act_case : cases) { - // // Direct CONV_2D - // test_cases.emplace_back(new test_conv_2d_implicit( - // { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] }, - // { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] }, - // kernel_type, 1, 1, 0, 0, 1, 1, false)); - // } - // } + for (auto kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) { + for (auto act_case : cases) { + // Direct CONV_2D + test_cases.emplace_back(new test_conv_2d_implicit( + { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] }, + { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] }, + kernel_type, 1, 1, 0, 0, 1, 1, false)); + } + } // Stable-diffusion layers std::map idx_sd{ From f0a480cc221aedf106f015410341bb22a740c370 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Tue, 21 Oct 2025 15:43:35 -0400 Subject: [PATCH 019/122] WIP --- ggml/src/ggml-cuda/conv2d-implicit.cu | 176 ++++++++++++++++++++++++++ 1 file changed, 176 insertions(+) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 93ede3efc8..174b9b46ba 100644 --- 
a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -730,6 +730,182 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, } } +template +static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, + const half * __restrict__ kernel, + float * __restrict__ output, + const param_t param) { + constexpr unsigned int MMA_M = 16; + constexpr unsigned int MMA_N = 8; + + const unsigned int K = param.c * param.r * param.s; + + // for convenience/readability in index calculations + const unsigned int A_stride = K; + const unsigned int B_stride = N; + const unsigned int CD_stride = N; + + // calculate how many bits of shared memory indices are going to be swizzled, and create masks + constexpr unsigned int SWIZZLE_BITS_B = int_log2(BN / 8); + + // loop bounds, constexpr where possible allows for loop unrolling + constexpr unsigned int mma_tiles_per_warp_k = 4; + constexpr unsigned int mma_tiles_per_warp_m = WM / MMA_M; + constexpr unsigned int mma_tiles_per_warp_n = WN / MMA_N; + const unsigned int num_block_tiles_k = K / BK; + + // calculate block/warp indices + const unsigned int block_m = blockIdx.y; + const unsigned int block_n = blockIdx.x; + const unsigned int warp_m = threadIdx.y; + const unsigned int warp_n = threadIdx.x / 32; + + // double buffering + extern __shared__ half shmem[]; + half* A_block_smem = shmem; + half* B_block_smem = &shmem[BM * BK]; + constexpr int BUFFER_SIZE = BM * BK + BK * BN; + + // declare register storage + // ptx instructions expect uint32_t registers, where each uint32_t is 2 halfs packed together + uint32_t acc_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][2]; + uint32_t A_register[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2]; + uint32_t B_register[mma_tiles_per_warp_k][mma_tiles_per_warp_n]; + + // convenience cast to half for register storage + half (&acc_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_n][4] = reinterpret_cast(acc_register); + half (&A_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] = reinterpret_cast(A_register); + half (&B_register_) [mma_tiles_per_warp_k][mma_tiles_per_warp_n][2] = reinterpret_cast(B_register); + + // accumulators start at 0 + for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++) + { + for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++) + { + acc_register_[mma_m][mma_n][0] = 0; + acc_register_[mma_m][mma_n][1] = 0; + acc_register_[mma_m][mma_n][2] = 0; + acc_register_[mma_m][mma_n][3] = 0; + } + } + + // these register arrays are used to cache values pre-fetched from global memory during the inner loop of the kernel + // the code is nicer if we hard code it for these tile dimensions and number of threads + // since we performing this copy with float4 pointers, for these tile dimensions it works out to be 8 float4s for A and 4 float4s for B + static_assert(BM_dim == 256); + static_assert(BN_dim == 256); + static_assert(BK_dim == 32); + static_assert(NUM_THREADS == 256); + float4 A_gmem_cache_reg[4]; + float4 B_gmem_cache_reg[4]; + + // prefetch the first block tile of A,B into shared memory + half* A_block_gmem = A + (block_m * BM_dim * A_stride); + half* B_block_gmem = B + (block_n * BN_dim); + tileMemcpySwizzleA(A_block_gmem, A_block_smem, K); + tileMemcpySwizzle(B_block_gmem, B_block_smem, N); + + // construct const pointers to warp tiles for use inside the inner loop + + + int offset_direction = 1; + + for (unsigned int block_k = 1; block_k <= num_block_tiles_k; block_k++) + { + 
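// Sketch of the double-buffering scheme set up above (illustrative only): the A/B
// tile pointers ping-pong between the two halves of the shared-memory allocation,
// so the next K-tile can be staged from global memory while the current one feeds
// the mma pipeline. Hypothetical stand-alone form of the pointer toggling used at
// the bottom of the main loop:
static __device__ __forceinline__ void swap_tile_buffers(half *& a_tile, half *& b_tile,
                                                          int & direction, int buffer_elems) {
    a_tile += buffer_elems * direction;   // jump to the other half of the buffer
    b_tile += buffer_elems * direction;
    direction = -direction;               // flip so the next call jumps back
}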
__syncthreads(); + + if (block_k != num_block_tiles_k) + { + half* A_block_gmem = A + (block_m * BM_dim * A_stride) + (block_k * BK_dim); + half* B_block_gmem = B + (block_k * BK_dim * B_stride) + (block_n * BN_dim); + tileMemcpyLoad(A_block_gmem, A_gmem_cache_reg, K); + tileMemcpyLoad(B_block_gmem, B_gmem_cache_reg, N); + } + half* A_warp_tile = A_block_smem + (warp_m * WM_dim * BK_dim); + half* B_warp_tile = B_block_smem + (warp_n * WN_dim); + + ldmatrix_a(A_warp_tile, A_register_); + ldmatrix_b(B_warp_tile, B_register_); + + // outer product between mma tiles + #pragma unroll + for (unsigned int mma_k = 0; mma_k < mma_tiles_per_warp_k; mma_k++) + { + #pragma unroll + for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++) + { + #pragma unroll + for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++) + { + asm volatile ( + "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " + "{%0, %1}, " + "{%2, %3}, " + "{%4}, " + "{%5, %6};" + : "=r"(acc_register[mma_m][mma_n][0]), "=r"(acc_register[mma_m][mma_n][1]) + : "r"(A_register[mma_m][mma_k][0]), "r"(A_register[mma_m][mma_k][1]), + "r"(B_register[mma_k][mma_n]) + "r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1]) + ); + } + } + } + + + if (block_k != num_block_tiles_k) + { + // switch smem buffers each iteration + A_block_smem = A_block_smem + BUFFER_SIZE * offset_direction; + B_block_smem = B_block_smem + BUFFER_SIZE * offset_direction; + offset_direction = -1 * offset_direction; + + tileMemcpySwizzleStoreA(A_gmem_cache_reg, A_block_smem); + tileMemcpySwizzleStore(B_gmem_cache_reg, B_block_smem); + } + } + + ////////////// + // epilogue // + ////////////// + half alpha_ = (half)alpha; + half beta_ = (half)beta; + half C_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][4]; + + // calculate pointers for this warps C and D tiles + half* C_block_gmem = C + (block_m * BM_dim * CD_stride) + (block_n * BN_dim); + half* C_warp_gmem = C_block_gmem + (warp_m * WM_dim * CD_stride) + (warp_n * WN_dim); + half* D_block_gmem = D + (block_m * BM_dim * CD_stride) + (block_n * BN_dim); + half* D_warp_gmem = D_block_gmem + (warp_m * WM_dim * CD_stride) + (warp_n * WN_dim); + + for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++) + { + for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++) + { + half* C_mma_tile = C_warp_gmem + (mma_m * MMA_M_dim * CD_stride) + (mma_n * MMA_N_dim); + ldmatrix_m16n8_gmem(C_mma_tile, C_register[mma_m][mma_n], N * sizeof(half)); + + // scale C by beta + acc_register_[mma_m][mma_n][0] = acc_register_[mma_m][mma_n][0] * alpha_ + C_register[mma_m][mma_n][0] * beta_; + acc_register_[mma_m][mma_n][1] = acc_register_[mma_m][mma_n][1] * alpha_ + C_register[mma_m][mma_n][1] * beta_; + acc_register_[mma_m][mma_n][2] = acc_register_[mma_m][mma_n][2] * alpha_ + C_register[mma_m][mma_n][2] * beta_; + acc_register_[mma_m][mma_n][3] = acc_register_[mma_m][mma_n][3] * alpha_ + C_register[mma_m][mma_n][3] * beta_; + } + } + + for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++) + { + for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++) + { + half* D_mma_tile = D_warp_gmem + (mma_m * MMA_M_dim * CD_stride) + (mma_n * MMA_N_dim); + stmatrix_m16n8(D_mma_tile, acc_register_[mma_m][mma_n], N * sizeof(half)); + } + } + +} + + #define NUM_VARIANTS 6 /* From f931ad883f94f4e9a67e6ce89df6dfffd402c3a4 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Tue, 21 Oct 2025 17:12:50 -0400 Subject: [PATCH 020/122] WIP --- ggml/src/ggml-cuda/conv2d-implicit.cu | 13 
++- ggml/src/ggml-cuda/conv2d-implicit.cuh | 133 +++++++++++++++++++++++++ 2 files changed, 141 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 174b9b46ba..360127b8d5 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -730,6 +730,8 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, } } +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + template static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, @@ -793,16 +795,16 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, // these register arrays are used to cache values pre-fetched from global memory during the inner loop of the kernel // the code is nicer if we hard code it for these tile dimensions and number of threads // since we performing this copy with float4 pointers, for these tile dimensions it works out to be 8 float4s for A and 4 float4s for B - static_assert(BM_dim == 256); - static_assert(BN_dim == 256); - static_assert(BK_dim == 32); + static_assert(BM == 256); + static_assert(BN == 256); + static_assert(BK == 32); static_assert(NUM_THREADS == 256); float4 A_gmem_cache_reg[4]; float4 B_gmem_cache_reg[4]; // prefetch the first block tile of A,B into shared memory - half* A_block_gmem = A + (block_m * BM_dim * A_stride); - half* B_block_gmem = B + (block_n * BN_dim); + half* A_block_gmem = input + (block_m * BM * A_stride); + half* B_block_gmem = weight + (block_n * BN); tileMemcpySwizzleA(A_block_gmem, A_block_smem, K); tileMemcpySwizzle(B_block_gmem, B_block_smem, N); @@ -905,6 +907,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, } +#endif #define NUM_VARIANTS 6 diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index e46d93ef4f..0a5fc4ab6a 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -25,6 +25,139 @@ typedef struct{ uint3 S_fastdiv; } param_t; +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +// same as above, but writes are swizzled to avoid bank conflicts when shared memory is read later in the kernel +template +__device__ __forceinline__ void tileMemcpySwizzle( + half* src, + half* dst, + const unsigned int src_stride +) +{ + constexpr unsigned int SWIZZLE_MASK = 0b111 << SWIZZLE_BITS; + + // reinterpret input/output as float4 + float4* src_float4 = reinterpret_cast(src); + float4* dst_float4 = reinterpret_cast(dst); + const unsigned int src_stride_vectorized = src_stride / 8; + + // # of threads is multiple of # of columns in the tile + constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; + static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); + + // flatten out 2d grid of threads into in order of increasing threadIdx.x + const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; + + // assign each thread a row/column in the tile, calculate how many iterations we need + // to cover the whole tile + constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; + constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; + unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; + const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; + + #pragma unroll + for (unsigned int i = 0; i < NUM_ITERS; i++) + { + // apply swizzle to the dst index + const unsigned int src_index = thread_row * src_stride_vectorized + thread_col; + unsigned int dst_index = thread_row 
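// Illustrative form of the XOR swizzle applied to dst_index here: a few bits derived
// from the row of a 16-byte element are folded into its column bits, so float4s that
// would otherwise share a bank across rows are spread out. The consumer (ldmatrix)
// computes the same swizzled offset when it reads, so no unswizzling pass is needed.
// Hypothetical helper mirroring the mask/shift scheme:
static __host__ __device__ unsigned swizzle_index(unsigned idx, unsigned swizzle_bits) {
    const unsigned mask = 0b111u << swizzle_bits;    // three row-derived bits
    return idx ^ ((idx & mask) >> swizzle_bits);     // folded into the column bits
}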
* TILE_COLS_VECTORIZED + thread_col; + dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK) >> SWIZZLE_BITS); + dst_float4[dst_index] = src_float4[src_index]; + thread_row += ROW_STEP; + } +} + + +// this is a special case of the above for when TILE_COLS == 32 +template +__device__ __forceinline__ void tileMemcpySwizzleA( + half* src, + half* dst, + const unsigned int src_stride +) +{ + constexpr unsigned int SWIZZLE_MASK_1 = 0b10000; + constexpr unsigned int SWIZZLE_BITS_1 = 4; + constexpr unsigned int SWIZZLE_MASK_2 = 0b1100; + constexpr unsigned int SWIZZLE_BITS_2 = 2; + constexpr unsigned int TILE_COLS = 32; + + // reinterpret input/output as float4 + float4* src_float4 = reinterpret_cast(src); + float4* dst_float4 = reinterpret_cast(dst); + const unsigned int src_stride_vectorized = src_stride / 8; + + // # of threads is multiple of # of columns in the tile + constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; + static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); + + // flatten out 2d grid of threads into in order of increasing threadIdx.x + const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; + + // assign each thread a row/column in the tile, calculate how many iterations we need + // to cover the whole tile + constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; + constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; + unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; + const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; + + #pragma unroll + for (unsigned int i = 0; i < NUM_ITERS; i++) + { + // apply swizzle to the dst index + const unsigned int src_index = thread_row * src_stride_vectorized + thread_col; + unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; + dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); + dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); + dst_float4[dst_index] = src_float4[src_index]; + thread_row += ROW_STEP; + } +} + +template +__device__ __forceinline__ void tileMemcpyLoad( + half* src, + float4 (&dst_reg)[ELEMENTS_PER_THREAD], + const unsigned int src_stride +) +{ + // reinterpret input/output as float4 + float4* src_float4 = reinterpret_cast(src); + const unsigned int src_stride_vectorized = src_stride / 8; + + // # of threads is multiple of # of columns in the tile + constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; + static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); + + // flatten out 2d grid of threads into in order of increasing threadIdx.x + const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; + + // assign each thread a row/column in the tile, calculate how many iterations we need + // to cover the whole tile + constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; + constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; + unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; + const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; + + // compile time check that we provided the right amount of registers for storage + static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); + + #pragma unroll + for (unsigned int i = 0; i < NUM_ITERS; i++) + { + const unsigned int src_index = thread_row * src_stride_vectorized + thread_col; + dst_reg[i] = src_float4[src_index]; + thread_row += ROW_STEP; + } +} +#endif #define CUDA_CONV2D_IMPLICT_BLOCK_SIZE 256 void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst); From 
1b69ed44c69971d0199108c8e636a5f8e0a65372 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Tue, 21 Oct 2025 17:15:26 -0400 Subject: [PATCH 021/122] WIP --- ggml/src/ggml-cuda/conv2d-implicit.cu | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 360127b8d5..f646cf73b3 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -805,8 +805,8 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, // prefetch the first block tile of A,B into shared memory half* A_block_gmem = input + (block_m * BM * A_stride); half* B_block_gmem = weight + (block_n * BN); - tileMemcpySwizzleA(A_block_gmem, A_block_smem, K); - tileMemcpySwizzle(B_block_gmem, B_block_smem, N); + tileMemcpySwizzleA(A_block_gmem, A_block_smem, K); + tileMemcpySwizzle(B_block_gmem, B_block_smem, N); // construct const pointers to warp tiles for use inside the inner loop @@ -819,16 +819,16 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, if (block_k != num_block_tiles_k) { - half* A_block_gmem = A + (block_m * BM_dim * A_stride) + (block_k * BK_dim); - half* B_block_gmem = B + (block_k * BK_dim * B_stride) + (block_n * BN_dim); - tileMemcpyLoad(A_block_gmem, A_gmem_cache_reg, K); - tileMemcpyLoad(B_block_gmem, B_gmem_cache_reg, N); + half* A_block_gmem = A + (block_m * BM * A_stride) + (block_k * BK); + half* B_block_gmem = B + (block_k * BK * B_stride) + (block_n * BN); + tileMemcpyLoad(A_block_gmem, A_gmem_cache_reg, K); + tileMemcpyLoad(B_block_gmem, B_gmem_cache_reg, N); } - half* A_warp_tile = A_block_smem + (warp_m * WM_dim * BK_dim); - half* B_warp_tile = B_block_smem + (warp_n * WN_dim); + half* A_warp_tile = A_block_smem + (warp_m * WM * BK); + half* B_warp_tile = B_block_smem + (warp_n * WN); - ldmatrix_a(A_warp_tile, A_register_); - ldmatrix_b(B_warp_tile, B_register_); + ldmatrix_a(A_warp_tile, A_register_); + ldmatrix_b(B_warp_tile, B_register_); // outer product between mma tiles #pragma unroll @@ -863,8 +863,8 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, B_block_smem = B_block_smem + BUFFER_SIZE * offset_direction; offset_direction = -1 * offset_direction; - tileMemcpySwizzleStoreA(A_gmem_cache_reg, A_block_smem); - tileMemcpySwizzleStore(B_gmem_cache_reg, B_block_smem); + tileMemcpySwizzleStoreA(A_gmem_cache_reg, A_block_smem); + tileMemcpySwizzleStore(B_gmem_cache_reg, B_block_smem); } } From 215ebf6526cbd8ebc0966a8a0d9c2760910a5c20 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 22 Oct 2025 15:56:55 -0400 Subject: [PATCH 022/122] WIP --- ggml/src/ggml-cuda/conv2d-implicit.cu | 247 ++++++++++++++++++++++++- ggml/src/ggml-cuda/conv2d-implicit.cuh | 106 +++++++++-- 2 files changed, 328 insertions(+), 25 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index f646cf73b3..1866517775 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -732,6 +732,236 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +template +__device__ __forceinline__ void ldmatrix_a( + const half* src, + half (®)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] +) +{ + static_assert(mma_tiles_per_warp_m == 8, "mma_tiles_per_warp_m must be 4"); + static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4"); + + 
uint32_t (®_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][2] = reinterpret_cast(reg); + unsigned int logical_offset = (threadIdx.x % 32) * smem_stride; + unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b10000000) >> 4); + swizzled_offset = swizzled_offset ^ ((swizzled_offset & 0b1100000) >> 2); + uint32_t src_addr = cvta_to_shared_u32(src + swizzled_offset); + constexpr unsigned int smem_stride_ = smem_stride * sizeof(half); // convert stride to bytes + + // 0 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[0][0][0]), "=r"(reg_[0][0][1]), "=r"(reg_[1][0][0]), "=r"(reg_[1][0][1]) + : "r"(src_addr) + ); + + // 0 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[2][0][0]), "=r"(reg_[2][0][1]), "=r"(reg_[3][0][0]), "=r"(reg_[3][0][1]) + : "r"(src_addr + 32 * smem_stride_) + ); + + // 0 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[4][0][0]), "=r"(reg_[4][0][1]), "=r"(reg_[5][0][0]), "=r"(reg_[5][0][1]) + : "r"(src_addr + 64 * smem_stride_) + ); + + // 0 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[6][0][0]), "=r"(reg_[6][0][1]), "=r"(reg_[7][0][0]), "=r"(reg_[7][0][1]) + : "r"(src_addr + 96 * smem_stride_) + ); + + src_addr ^= 0b10000; + + // 1 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[0][1][0]), "=r"(reg_[0][1][1]), "=r"(reg_[1][1][0]), "=r"(reg_[1][1][1]) + : "r"(src_addr) + ); + + // 1 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[2][1][0]), "=r"(reg_[2][1][1]), "=r"(reg_[3][1][0]), "=r"(reg_[3][1][1]) + : "r"(src_addr + 32 * smem_stride_) + ); + + // 1 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[4][1][0]), "=r"(reg_[4][1][1]), "=r"(reg_[5][1][0]), "=r"(reg_[5][1][1]) + : "r"(src_addr + 64 * smem_stride_) + ); + + // 1 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[6][1][0]), "=r"(reg_[6][1][1]), "=r"(reg_[7][1][0]), "=r"(reg_[7][1][1]) + : "r"(src_addr + 96 * smem_stride_) + ); + + src_addr ^= 0b110000; + + // 2 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[0][2][0]), "=r"(reg_[0][2][1]), "=r"(reg_[1][2][0]), "=r"(reg_[1][2][1]) + : "r"(src_addr) + ); + + // 2 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[2][2][0]), "=r"(reg_[2][2][1]), "=r"(reg_[3][2][0]), "=r"(reg_[3][2][1]) + : "r"(src_addr + 32 * smem_stride_) + ); + + // 2 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[4][2][0]), "=r"(reg_[4][2][1]), "=r"(reg_[5][2][0]), "=r"(reg_[5][2][1]) + : "r"(src_addr + 64 * smem_stride_) + ); + + // 2 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[6][2][0]), "=r"(reg_[6][2][1]), "=r"(reg_[7][2][0]), "=r"(reg_[7][2][1]) + : "r"(src_addr + 96 * smem_stride_) + ); + src_addr ^= 0b10000; + + // 3 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[0][3][0]), "=r"(reg_[0][3][1]), "=r"(reg_[1][3][0]), "=r"(reg_[1][3][1]) + : "r"(src_addr) + ); + + // 3 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, 
[%4];" + : "=r"(reg_[2][3][0]), "=r"(reg_[2][3][1]), "=r"(reg_[3][3][0]), "=r"(reg_[3][3][1]) + : "r"(src_addr + 32 * smem_stride_) + ); + + // 3 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[4][3][0]), "=r"(reg_[4][3][1]), "=r"(reg_[5][3][0]), "=r"(reg_[5][3][1]) + : "r"(src_addr + 64 * smem_stride_) + ); + + // 3 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[6][3][0]), "=r"(reg_[6][3][1]), "=r"(reg_[7][3][0]), "=r"(reg_[7][3][1]) + : "r"(src_addr + 96 * smem_stride_) + ); + +} + +template +__device__ __forceinline__ void ldmatrix_b( + const half* src, + half (®)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2] +) +{ + static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4"); + static_assert(mma_tiles_per_warp_n == 8, "mma_tiles_per_warp_n must be 8"); + + uint32_t (®_) [4][8] = reinterpret_cast(reg); + const unsigned int logical_offset = ((threadIdx.x % 8) * smem_stride) + (((threadIdx.x % 32) / 8) * 8); + unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b11100000000) >> 5); + uint32_t src_addr = cvta_to_shared_u32(src + swizzled_offset); + constexpr unsigned int smem_stride_ = smem_stride * sizeof(half); // convert stride to bytes + + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[0][0]), "=r"(reg_[0][1]), "=r"(reg_[0][2]), "=r"(reg_[0][3]) + : "r"(src_addr) + ); + + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[0][4]), "=r"(reg_[0][5]), "=r"(reg_[0][6]), "=r"(reg_[0][7]) + : "r"(src_addr ^ 0b1000000) + ); + + src_addr += 8 * smem_stride_; + + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[1][0]), "=r"(reg_[1][1]), "=r"(reg_[1][2]), "=r"(reg_[1][3]) + : "r"(src_addr) + ); + + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[1][4]), "=r"(reg_[1][5]), "=r"(reg_[1][6]), "=r"(reg_[1][7]) + : "r"(src_addr ^ 0b1000000) + ); + + src_addr += 8 * smem_stride_; + + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[2][0]), "=r"(reg_[2][1]), "=r"(reg_[2][2]), "=r"(reg_[2][3]) + : "r"(src_addr) + ); + + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[2][4]), "=r"(reg_[2][5]), "=r"(reg_[2][6]), "=r"(reg_[2][7]) + : "r"(src_addr ^ 0b1000000) + ); + + src_addr += 8 * smem_stride_; + + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[3][0]), "=r"(reg_[3][1]), "=r"(reg_[3][2]), "=r"(reg_[3][3]) + : "r"(src_addr) + ); + + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[3][4]), "=r"(reg_[3][5]), "=r"(reg_[3][6]), "=r"(reg_[3][7]) + : "r"(src_addr ^ 0b1000000) + ); + +} + template static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, @@ -742,6 +972,9 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, constexpr unsigned int MMA_N = 8; const unsigned int K = param.c * param.r * param.s; + const uint PQ = param.Oh * param.Ow; + const uint inChannelOffset = param.c * param.w; + const uint weightKOffset = param.c * param.r * param.s; // for convenience/readability in index calculations const unsigned int A_stride = K; @@ -801,12 
+1034,13 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, static_assert(NUM_THREADS == 256); float4 A_gmem_cache_reg[4]; float4 B_gmem_cache_reg[4]; - + // prefetch the first block tile of A,B into shared memory - half* A_block_gmem = input + (block_m * BM * A_stride); - half* B_block_gmem = weight + (block_n * BN); - tileMemcpySwizzleA(A_block_gmem, A_block_smem, K); - tileMemcpySwizzle(B_block_gmem, B_block_smem, N); +// half* A_block_gmem = input + (block_m * BM * A_stride); + half* A_block_gmem = input; + half* B_block_gmem = weight + (block_n * weightKOffset); + tileMemcpySwizzleA(A_block_gmem, A_block_smem, inChannelOffset, param); + tileMemcpySwizzleB(B_block_gmem, B_block_smem, weightKOffset, param); // construct const pointers to warp tiles for use inside the inner loop @@ -864,7 +1098,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, offset_direction = -1 * offset_direction; tileMemcpySwizzleStoreA(A_gmem_cache_reg, A_block_smem); - tileMemcpySwizzleStore(B_gmem_cache_reg, B_block_smem); + tileMemcpySwizzleStoreB (B_gmem_cache_reg, B_block_smem); } } @@ -1026,6 +1260,7 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * param_t params = { B, IC, IH, IW, OC, KH, KW, ST_Y, ST_X, PD_Y, PD_X, DL_Y, DL_X, OH, OW }; params.SC_fastdiv = init_fastdiv_values(KW*IC); params.OW_fastdiv = init_fastdiv_values(OW); + params.OHOW_fastdiv = init_fastdiv_values(OW*OH); params.C_fastdiv = init_fastdiv_values(IC); params.RS_fastdiv = init_fastdiv_values(KW*KH); params.S_fastdiv = init_fastdiv_values(KW); diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index 0a5fc4ab6a..9c15d72c8f 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -23,31 +23,70 @@ typedef struct{ uint3 C_fastdiv; uint3 RS_fastdiv; uint3 S_fastdiv; + uint3 OHOW_fastdiv; } param_t; #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE // same as above, but writes are swizzled to avoid bank conflicts when shared memory is read later in the kernel template -__device__ __forceinline__ void tileMemcpySwizzle( +unsigned int NUM_THREADS> +__device__ __forceinline__ void tileMemcpySwizzleB( half* src, half* dst, const unsigned int src_stride ) { - constexpr unsigned int SWIZZLE_MASK = 0b111 << SWIZZLE_BITS; + // constexpr unsigned int SWIZZLE_MASK = 0b111 << SWIZZLE_BITS; + + // // reinterpret input/output as float4 + // float4* src_float4 = reinterpret_cast(src); + // float4* dst_float4 = reinterpret_cast(dst); + // const unsigned int src_stride_vectorized = src_stride / 8; + + // // # of threads is multiple of # of columns in the tile + // constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; + // static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); + + // // flatten out 2d grid of threads into in order of increasing threadIdx.x + // const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; + + // // assign each thread a row/column in the tile, calculate how many iterations we need + // // to cover the whole tile + // constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; + // constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; + // unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; + // const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; + + // #pragma unroll + // for (unsigned int i = 0; i < NUM_ITERS; i++) + // { + // // apply swizzle to the dst index + // const unsigned int src_index = 
thread_row * src_stride_vectorized + thread_col; + // unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; + // dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK) >> SWIZZLE_BITS); + // if (thread_col * 8 < param.k && start_k + innerColA * 4 < end_k){ + // float4 tmp = reinterpret_cast(&src[thread_row * src_stride_vectorized + thread_col*8)[0]; + // dst_float4[dst_index] = src_float4[src_index]; + // }else{ // read 4 halves + // dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f); + // } + // thread_row += ROW_STEP; + // } + + constexpr unsigned int SWIZZLE_MASK_1 = 0b10000; + constexpr unsigned int SWIZZLE_BITS_1 = 4; + constexpr unsigned int SWIZZLE_MASK_2 = 0b1100; + constexpr unsigned int SWIZZLE_BITS_2 = 2; + constexpr unsigned int TILE_COLS = 32; // reinterpret input/output as float4 - float4* src_float4 = reinterpret_cast(src); + // float4* src_float4 = reinterpret_cast(src); float4* dst_float4 = reinterpret_cast(dst); - const unsigned int src_stride_vectorized = src_stride / 8; + // const unsigned int src_stride_vectorized = src_stride / 8; // # of threads is multiple of # of columns in the tile constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); - // flatten out 2d grid of threads into in order of increasing threadIdx.x const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; @@ -57,15 +96,24 @@ __device__ __forceinline__ void tileMemcpySwizzle( constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; - + // TODO: next block_k loop + const uint curR = fastdiv(thread_col*8, param.SC_fastdiv); // channel offset + const uint curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const uint curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // + #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++) { // apply swizzle to the dst index - const unsigned int src_index = thread_row * src_stride_vectorized + thread_col; + const unsigned int src_index = thread_row * src_stride + thread_col * 8; unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; - dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK) >> SWIZZLE_BITS); - dst_float4[dst_index] = src_float4[src_index]; + dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); + dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); + if (thread_row < param.k && curR < param.R && curS < param.S && curC < param.c){ + dst_float4[dst_index] = reinterpret_cast(&src[src_index])[0]; + }else{ // read 4 halves + dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f); + } thread_row += ROW_STEP; } } @@ -77,7 +125,9 @@ unsigned int NUM_THREADS> __device__ __forceinline__ void tileMemcpySwizzleA( half* src, half* dst, - const unsigned int src_stride + // const unsigned int src_stride, + const unsigned int inChannelOffset, + param_t param ) { constexpr unsigned int SWIZZLE_MASK_1 = 0b10000; @@ -87,14 +137,13 @@ __device__ __forceinline__ void tileMemcpySwizzleA( constexpr unsigned int TILE_COLS = 32; // reinterpret input/output as float4 - float4* src_float4 = reinterpret_cast(src); + // float4* src_float4 = reinterpret_cast(src); float4* dst_float4 = reinterpret_cast(dst); - const unsigned int src_stride_vectorized = src_stride / 8; + // const unsigned int src_stride_vectorized = src_stride / 8; 
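+    // Note on the swizzled store below: dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col
+    // indexes float4 slots of the 32-column (4 x float4) tile, and the two XOR steps fold the low
+    // row bits back into the float4 column so that consecutive rows land in different shared-memory
+    // bank groups when the tile is read back later in the kernel. A minimal host-side sketch of the
+    // same index mapping, using the same masks as below (illustration only, not part of the copy path):
+    //   unsigned int swizzle(unsigned int idx) {
+    //       idx ^= (idx & 0b10000u) >> 4;   // fold bit 4 into bit 0
+    //       idx ^= (idx & 0b1100u)  >> 2;   // fold bits 3:2 into bits 1:0
+    //       return idx;
+    //   }
+    // e.g. the first float4 of rows 0..3 (indices 0, 4, 8, 12) maps to 0, 5, 10, 15, so each of
+    // those rows stores its first vector into a different 16-byte column of the tile.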
// # of threads is multiple of # of columns in the tile constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); - // flatten out 2d grid of threads into in order of increasing threadIdx.x const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; @@ -104,16 +153,35 @@ __device__ __forceinline__ void tileMemcpySwizzleA( constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; - + + #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++) { + unsigned int gemm_i = blockDim.y * TILE_ROWS + thread_row; + unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv); + unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv); + int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p; + int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q; + unsigned int inOffset = n * param.c * param.h * param.w; + // TODO: next block_k loop + const uint curR = fastdiv(thread_col*8, param.SC_fastdiv); // channel offset + const uint curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const uint curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + int curH = posh_ori + curR * param.d_h; // input h + int curW = posw_ori + curS * param.d_w; // input w // apply swizzle to the dst index - const unsigned int src_index = thread_row * src_stride_vectorized + thread_col; unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); - dst_float4[dst_index] = src_float4[src_index]; + if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && + curR < param.R && curS < param.S && curC < param.c){ + // const unsigned int src_index = thread_row * src_stride_vectorized + thread_col; + const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; + dst_float4[dst_index] = reinterpret_cast(&src[inOffset + inOffsetTmp])[0]; + } else{ + dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f); + } thread_row += ROW_STEP; } } From 66f6d16265cc76984efe8a4b4d739c6ca2a0e0e0 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Thu, 23 Oct 2025 13:52:26 -0400 Subject: [PATCH 023/122] WIP --- ggml/src/ggml-cuda/conv2d-implicit.cu | 127 ++++++++------ ggml/src/ggml-cuda/conv2d-implicit.cuh | 232 ++++++++++++++++++++++++- 2 files changed, 301 insertions(+), 58 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 1866517775..a11d306c6c 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -892,27 +892,43 @@ __device__ __forceinline__ void ldmatrix_b( static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4"); static_assert(mma_tiles_per_warp_n == 8, "mma_tiles_per_warp_n must be 8"); - uint32_t (®_) [4][8] = reinterpret_cast(reg); - const unsigned int logical_offset = ((threadIdx.x % 8) * smem_stride) + (((threadIdx.x % 32) / 8) * 8); - unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b11100000000) >> 5); +// uint32_t (®_) [4][8] = reinterpret_cast(reg); +// const unsigned int logical_offset = ((threadIdx.x % 8) * smem_stride) + (((threadIdx.x % 32) / 8) * 8); +// unsigned int swizzled_offset = logical_offset ^ 
((logical_offset & 0b11100000000) >> 5); +// uint32_t src_addr = cvta_to_shared_u32(src + swizzled_offset); +// constexpr unsigned int smem_stride_ = smem_stride * sizeof(half); // convert stride to bytes + unsigned int logical_offset = (threadIdx.x % 32) * smem_stride; + unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b10000000) >> 4); + swizzled_offset = swizzled_offset ^ ((swizzled_offset & 0b1100000) >> 2); uint32_t src_addr = cvta_to_shared_u32(src + swizzled_offset); constexpr unsigned int smem_stride_ = smem_stride * sizeof(half); // convert stride to bytes + +// asm volatile ( +// "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " +// "{%0, %1, %2, %3}, [%4];" +// : "=r"(reg_[0][0]), "=r"(reg_[0][1]), "=r"(reg_[0][2]), "=r"(reg_[0][3]) +// : "r"(src_addr) +// ); + + // 0 asm volatile ( - "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " - "{%0, %1, %2, %3}, [%4];" - : "=r"(reg_[0][0]), "=r"(reg_[0][1]), "=r"(reg_[0][2]), "=r"(reg_[0][3]) - : "r"(src_addr) - ); + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[0][0]), "=r"(reg_[0][1]), "=r"(reg_[0][2]), "=r"(reg_[0][3]) + : "r"(src_addr) + ); + asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " "{%0, %1, %2, %3}, [%4];" : "=r"(reg_[0][4]), "=r"(reg_[0][5]), "=r"(reg_[0][6]), "=r"(reg_[0][7]) - : "r"(src_addr ^ 0b1000000) + // : "r"(src_addr ^ 0b1000000) + : "r"(src_addr + 32 * smem_stride_) ); - src_addr += 8 * smem_stride_; + src_addr ^= 0b10000; asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " @@ -925,10 +941,12 @@ __device__ __forceinline__ void ldmatrix_b( "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " "{%0, %1, %2, %3}, [%4];" : "=r"(reg_[1][4]), "=r"(reg_[1][5]), "=r"(reg_[1][6]), "=r"(reg_[1][7]) - : "r"(src_addr ^ 0b1000000) + // : "r"(src_addr ^ 0b1000000) + : "r"(src_addr + 32 * smem_stride_) ); - src_addr += 8 * smem_stride_; +// src_addr += 8 * smem_stride_; + src_addr ^= 0b110000; asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " @@ -941,10 +959,11 @@ __device__ __forceinline__ void ldmatrix_b( "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " "{%0, %1, %2, %3}, [%4];" : "=r"(reg_[2][4]), "=r"(reg_[2][5]), "=r"(reg_[2][6]), "=r"(reg_[2][7]) - : "r"(src_addr ^ 0b1000000) + // : "r"(src_addr ^ 0b1000000) + : "r"(src_addr + 32 * smem_stride_) ); - src_addr += 8 * smem_stride_; + src_addr ^= 0b10000; asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " @@ -957,7 +976,8 @@ __device__ __forceinline__ void ldmatrix_b( "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " "{%0, %1, %2, %3}, [%4];" : "=r"(reg_[3][4]), "=r"(reg_[3][5]), "=r"(reg_[3][6]), "=r"(reg_[3][7]) - : "r"(src_addr ^ 0b1000000) + // : "r"(src_addr ^ 0b1000000) + : "r"(src_addr + 32 * smem_stride_) ); } @@ -1038,7 +1058,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, // prefetch the first block tile of A,B into shared memory // half* A_block_gmem = input + (block_m * BM * A_stride); half* A_block_gmem = input; - half* B_block_gmem = weight + (block_n * weightKOffset); + half* B_block_gmem = kernel + (block_n * weightKOffset); tileMemcpySwizzleA(A_block_gmem, A_block_smem, inChannelOffset, param); tileMemcpySwizzleB(B_block_gmem, B_block_smem, weightKOffset, param); @@ -1053,16 +1073,17 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, if (block_k != num_block_tiles_k) { - half* A_block_gmem = A + (block_m * BM * A_stride) + (block_k * BK); - half* B_block_gmem = B + (block_k * BK 
* B_stride) + (block_n * BN); - tileMemcpyLoad(A_block_gmem, A_gmem_cache_reg, K); - tileMemcpyLoad(B_block_gmem, B_gmem_cache_reg, N); + // half* A_block_gmem = A + (block_m * BM * A_stride) + (block_k * BK); + half* A_block_gmem = input; + half* B_block_gmem = kernel + (block_n * weightKOffset); + tileMemcpyLoad(A_block_gmem, A_gmem_cache_reg, block_k * BK, inChannelOffset, param); + tileMemcpyLoad(B_block_gmem, B_gmem_cache_reg, block_k * BK, weightKOffset, param); } half* A_warp_tile = A_block_smem + (warp_m * WM * BK); - half* B_warp_tile = B_block_smem + (warp_n * WN); + half* B_warp_tile = B_block_smem + (warp_n * WN * BK); ldmatrix_a(A_warp_tile, A_register_); - ldmatrix_b(B_warp_tile, B_register_); + ldmatrix_b(B_warp_tile, B_register_); // outer product between mma tiles #pragma unroll @@ -1097,47 +1118,47 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, B_block_smem = B_block_smem + BUFFER_SIZE * offset_direction; offset_direction = -1 * offset_direction; - tileMemcpySwizzleStoreA(A_gmem_cache_reg, A_block_smem); - tileMemcpySwizzleStoreB (B_gmem_cache_reg, B_block_smem); + tileMemcpySwizzleStore(A_gmem_cache_reg, A_block_smem); + tileMemcpySwizzleStore(B_gmem_cache_reg, B_block_smem); } } ////////////// // epilogue // ////////////// - half alpha_ = (half)alpha; - half beta_ = (half)beta; - half C_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][4]; +// half alpha_ = (half)alpha; +// half beta_ = (half)beta; +// half C_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][4]; - // calculate pointers for this warps C and D tiles - half* C_block_gmem = C + (block_m * BM_dim * CD_stride) + (block_n * BN_dim); - half* C_warp_gmem = C_block_gmem + (warp_m * WM_dim * CD_stride) + (warp_n * WN_dim); - half* D_block_gmem = D + (block_m * BM_dim * CD_stride) + (block_n * BN_dim); - half* D_warp_gmem = D_block_gmem + (warp_m * WM_dim * CD_stride) + (warp_n * WN_dim); +// // calculate pointers for this warps C and D tiles +// half* C_block_gmem = C + (block_m * BM_dim * CD_stride) + (block_n * BN_dim); +// half* C_warp_gmem = C_block_gmem + (warp_m * WM_dim * CD_stride) + (warp_n * WN_dim); +// half* D_block_gmem = D + (block_m * BM_dim * CD_stride) + (block_n * BN_dim); +// half* D_warp_gmem = D_block_gmem + (warp_m * WM_dim * CD_stride) + (warp_n * WN_dim); - for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++) - { - for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++) - { - half* C_mma_tile = C_warp_gmem + (mma_m * MMA_M_dim * CD_stride) + (mma_n * MMA_N_dim); - ldmatrix_m16n8_gmem(C_mma_tile, C_register[mma_m][mma_n], N * sizeof(half)); +// for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++) +// { +// for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++) +// { +// half* C_mma_tile = C_warp_gmem + (mma_m * MMA_M_dim * CD_stride) + (mma_n * MMA_N_dim); +// ldmatrix_m16n8_gmem(C_mma_tile, C_register[mma_m][mma_n], N * sizeof(half)); - // scale C by beta - acc_register_[mma_m][mma_n][0] = acc_register_[mma_m][mma_n][0] * alpha_ + C_register[mma_m][mma_n][0] * beta_; - acc_register_[mma_m][mma_n][1] = acc_register_[mma_m][mma_n][1] * alpha_ + C_register[mma_m][mma_n][1] * beta_; - acc_register_[mma_m][mma_n][2] = acc_register_[mma_m][mma_n][2] * alpha_ + C_register[mma_m][mma_n][2] * beta_; - acc_register_[mma_m][mma_n][3] = acc_register_[mma_m][mma_n][3] * alpha_ + C_register[mma_m][mma_n][3] * beta_; - } - } +// // scale C by beta +// acc_register_[mma_m][mma_n][0] = 
acc_register_[mma_m][mma_n][0] * alpha_ + C_register[mma_m][mma_n][0] * beta_; +// acc_register_[mma_m][mma_n][1] = acc_register_[mma_m][mma_n][1] * alpha_ + C_register[mma_m][mma_n][1] * beta_; +// acc_register_[mma_m][mma_n][2] = acc_register_[mma_m][mma_n][2] * alpha_ + C_register[mma_m][mma_n][2] * beta_; +// acc_register_[mma_m][mma_n][3] = acc_register_[mma_m][mma_n][3] * alpha_ + C_register[mma_m][mma_n][3] * beta_; +// } +// } - for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++) - { - for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++) - { - half* D_mma_tile = D_warp_gmem + (mma_m * MMA_M_dim * CD_stride) + (mma_n * MMA_N_dim); - stmatrix_m16n8(D_mma_tile, acc_register_[mma_m][mma_n], N * sizeof(half)); - } - } +// for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++) +// { +// for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++) +// { +// half* D_mma_tile = D_warp_gmem + (mma_m * MMA_M_dim * CD_stride) + (mma_n * MMA_N_dim); +// stmatrix_m16n8(D_mma_tile, acc_register_[mma_m][mma_n], N * sizeof(half)); +// } +// } } diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index 9c15d72c8f..1a54a184a8 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -190,15 +190,176 @@ template -__device__ __forceinline__ void tileMemcpyLoad( +__device__ __forceinline__ void tileMemcpyLoadA( half* src, float4 (&dst_reg)[ELEMENTS_PER_THREAD], - const unsigned int src_stride + // const unsigned int src_stride, + const unsigned int block_k, + const unsigned int inChannelOffset, + param_t param ) { // reinterpret input/output as float4 float4* src_float4 = reinterpret_cast(src); - const unsigned int src_stride_vectorized = src_stride / 8; + // const unsigned int src_stride_vectorized = src_stride / 8; + + // # of threads is multiple of # of columns in the tile + constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; + static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); + + // flatten out 2d grid of threads into in order of increasing threadIdx.x + const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; + + // assign each thread a row/column in the tile, calculate how many iterations we need + // to cover the whole tile + constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; + constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; + unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; + const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; + + // compile time check that we provided the right amount of registers for storage + static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); + + #pragma unroll + for (unsigned int i = 0; i < NUM_ITERS; i++) + { + // const unsigned int src_index = thread_row * src_stride_vectorized + thread_col; + // dst_reg[i] = src_float4[src_index]; + // thread_row += ROW_STEP; + unsigned int gemm_i = blockDim.y * TILE_ROWS + thread_row; + unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv); + unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv); + int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p; + int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q; + unsigned int inOffset = n * param.c * param.h * param.w; + // TODO: next block_k loop + const uint curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // channel offset + const uint curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // 
kernel r offset + const uint curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + int curH = posh_ori + curR * param.d_h; // input h + int curW = posw_ori + curS * param.d_w; // input w + if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && + curR < param.R && curS < param.S && curC < param.c){ + // const unsigned int src_index = thread_row * src_stride_vectorized + thread_col; + const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; + dst_reg[i] = reinterpret_cast(&src[inOffset + inOffsetTmp])[0]; + } else{ + dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f); + } + thread_row += ROW_STEP; + } +} + + +template +__device__ __forceinline__ void tileMemcpyLoadB( + half* src, + float4 (&dst_reg)[ELEMENTS_PER_THREAD], + const unsigned int block_k, + const unsigned int src_stride, + param_t param +) +{ + // reinterpret input/output as float4 + float4* src_float4 = reinterpret_cast(src); + // const unsigned int src_stride_vectorized = src_stride / 8; + + // # of threads is multiple of # of columns in the tile + constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; + static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); + + // flatten out 2d grid of threads into in order of increasing threadIdx.x + const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; + + // assign each thread a row/column in the tile, calculate how many iterations we need + // to cover the whole tile + constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; + constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; + unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; + const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; + + // compile time check that we provided the right amount of registers for storage + static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); + + const uint curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // channel offset + const uint curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const uint curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // + + #pragma unroll + for (unsigned int i = 0; i < NUM_ITERS; i++) + { + // const unsigned int src_index = thread_row * src_stride_vectorized + thread_col; + // dst_reg[i] = src_float4[src_index]; + // thread_row += ROW_STEP; + const unsigned int src_index = thread_row * src_stride + block_k + thread_col * 8; + if (thread_row < param.k && curR < param.R && curS < param.S && curC < param.c){ + dst_reg[i] = reinterpret_cast(&src[src_index])[0]; + }else{ // read 4 halves + dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f); + } + thread_row += ROW_STEP; + } +} + +// template +// __device__ __forceinline__ void tileMemcpySwizzleStoreB( +// float4 src_reg[ELEMENTS_PER_THREAD], +// half* dst +// ) +// { +// constexpr unsigned int SWIZZLE_MASK = 0b111 << SWIZZLE_BITS; + +// // reinterpret input/output as float4 +// float4* dst_float4 = reinterpret_cast(dst); + +// // # of threads is multiple of # of columns in the tile +// constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; +// static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); + +// // flatten out 2d grid of threads into in order of increasing threadIdx.x +// const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; + +// // assign each thread a row/column in the tile, calculate how many iterations we need +// // to cover the whole tile +// constexpr unsigned 
int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; +// constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; +// unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; +// const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; + +// // compile time check that we provided the right amount of registers for storage +// static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); + +// #pragma unroll +// for (unsigned int i = 0; i < NUM_ITERS; i++) +// { +// // apply swizzle to the dst index +// unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; +// dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK) >> SWIZZLE_BITS); +// dst_float4[dst_index] = src_reg[i]; +// thread_row += ROW_STEP; +// } +// } + +// same as above but without the swizzle +template +__device__ __forceinline__ void tileMemcpyStore( + float4 src_reg[ELEMENTS_PER_THREAD], + half* dst, + unsigned int dst_stride_float4 +) +{ + // reinterpret input/output as float4 + float4* dst_float4 = reinterpret_cast(dst); // # of threads is multiple of # of columns in the tile constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; @@ -220,11 +381,72 @@ __device__ __forceinline__ void tileMemcpyLoad( #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++) { - const unsigned int src_index = thread_row * src_stride_vectorized + thread_col; - dst_reg[i] = src_float4[src_index]; + // apply swizzle to the dst index + unsigned int dst_index = thread_row * dst_stride_float4 + thread_col; + dst_float4[dst_index] = src_reg[i]; thread_row += ROW_STEP; } } + +// this is a special case of the above for when TILE_COLS == 32 +template +__device__ __forceinline__ void tileMemcpySwizzleStore( + const float4 (&src_reg)[ELEMENTS_PER_THREAD], + half* dst +) +{ + constexpr unsigned int SWIZZLE_MASK_1 = 0b10000; + constexpr unsigned int SWIZZLE_BITS_1 = 4; + constexpr unsigned int SWIZZLE_MASK_2 = 0b1100; + constexpr unsigned int SWIZZLE_BITS_2 = 2; + constexpr unsigned int TILE_COLS = 32; + + // reinterpret input/output as float4 + float4* dst_float4 = reinterpret_cast(dst); + + // # of threads is multiple of # of columns in the tile + constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; + static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); + + // flatten out 2d grid of threads into in order of increasing threadIdx.x + const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; + + // assign each thread a row/column in the tile, calculate how many iterations we need + // to cover the whole tile + constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; + constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; + unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; + const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; + + // compile time check that we provided the right amount of registers for storage + static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); + + #pragma unroll + for (unsigned int i = 0; i < NUM_ITERS; i++) + { + // apply swizzle to the dst index + unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; + dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); + dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); + dst_float4[dst_index] = src_reg[i]; + thread_row += ROW_STEP; + } +} + +__device__ __forceinline__ uint32_t cvta_to_shared_u32(const void *pointer) { + uint32_t address; + asm("{\n\t" + " .reg .u64 u64addr;\n\t" + " cvta.to.shared.u64 u64addr, %1;\n\t" + " cvt.u32.u64 %0, u64addr;\n\t" + 
"}" + : "=r"(address) + : "l"(pointer)); + return address; +} + #endif #define CUDA_CONV2D_IMPLICT_BLOCK_SIZE 256 From 2715341c1d4bb2c39b0bd268eca985bad4ba5ae8 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Thu, 23 Oct 2025 21:29:45 -0400 Subject: [PATCH 024/122] WIP: output --- ggml/src/ggml-cuda/conv2d-implicit.cu | 47 ++++++++++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index a11d306c6c..06bb4c53f1 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -986,7 +986,7 @@ template static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, const half * __restrict__ kernel, - float * __restrict__ output, + half * __restrict__ output, const param_t param) { constexpr unsigned int MMA_M = 16; constexpr unsigned int MMA_N = 8; @@ -1123,6 +1123,51 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, } } + // reuse smem + half *smemoutput = reinterpret_cast(shmem); + const uint lane_id = threadIdx.x % WARPSIZE; + const uint mma_row = lane_id / 4; + const uint mma_col = lane_id % 4; + const uint output_lds_addr = warp_id * WSUBM * WSUBN + lane_id; + const uint output_sts_addr = warp_m * WM * BN/2 + mma_row * BN/2 + warp_n * WN/2 + mma_col * 2; + const uint m_idx = by * BN + mma_tid_y * WN; + const uint n_idx = block_m * BM + warp_m * WM; + +#pragma unroll + for (int i = 0; i < 2; ++i) + { + for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++) + { + for (unsigned int mma_n = i * mma_tiles_per_warp_n/2; mma_n < (i+1)*mma_tiles_per_warp_n/2; mma_n++) + { + // output sts + uint32_t (®_)[2] = reinterpret_cast(acc_register_[mma_m][mma_n]); + uint32_t* dst_ptr = reinterpret_cast(&smemoutput[output_sts_addr + + mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N]); + dst_ptr[0] = reg_[0]; + dst_ptr = reinterpret_cast(&smemoutput[output_sts_addr + + mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N + 8 * BN / 2]); + dst_ptr[0] = reg_[1]; + } + } + __syncthreads(); +#pragma unroll + const uint row = m_idx + j * WSUBN + (lane_id + subk * WARPSIZE) / WSUBM; + const uint gemm_i = n_idx + i * WSUBM + (lane_id + subk * WARPSIZE) % WSUBM; + const int n = fastdiv(gemm_i, param.OHOW_fastdiv); + const int col = fastmodulo(gemm_i, param.OHOW_fastdiv); + if (n < param.n && row < param.k && col < param.Oh * param.Ow){ + // int outOffset = z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow + (m_idx + i * 16 + subk) * param.Oh * param.Ow + (n_idx + j * 32); + // if (n < param.n && (m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow) + // param.interm[outOffset] = smemoutput[output_lds_addr + subk * 32]; + const uint outOffset = ksplit > 0 ? 
+ z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow + + row * param.Oh * param.Ow + col : + z * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col; + output[outOffset] = smemoutput[output_lds_addr + subk * WARPSIZE]; + } + } + ////////////// // epilogue // ////////////// From 80a996cfc0019e615af23348072c735988357ee1 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Fri, 24 Oct 2025 11:41:11 -0400 Subject: [PATCH 025/122] WIP: tensore code compiled ok --- ggml/src/ggml-cuda/conv2d-implicit.cu | 184 ++++++++++++++----------- ggml/src/ggml-cuda/conv2d-implicit.cuh | 33 +++-- 2 files changed, 126 insertions(+), 91 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 06bb4c53f1..482270e2c7 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -730,7 +730,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, } } -#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + template __device__ __forceinline__ void ldmatrix_a( @@ -738,6 +738,7 @@ __device__ __forceinline__ void ldmatrix_a( half (®)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] ) { +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE static_assert(mma_tiles_per_warp_m == 8, "mma_tiles_per_warp_m must be 4"); static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4"); @@ -880,7 +881,11 @@ __device__ __forceinline__ void ldmatrix_a( : "=r"(reg_[6][3][0]), "=r"(reg_[6][3][1]), "=r"(reg_[7][3][0]), "=r"(reg_[7][3][1]) : "r"(src_addr + 96 * smem_stride_) ); - +#else + GGML_UNUSED(src); + GGML_UNUSED(reg); + NO_DEVICE_CODE; +#endif } template @@ -889,10 +894,11 @@ __device__ __forceinline__ void ldmatrix_b( half (®)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2] ) { +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4"); static_assert(mma_tiles_per_warp_n == 8, "mma_tiles_per_warp_n must be 8"); -// uint32_t (®_) [4][8] = reinterpret_cast(reg); + uint32_t (®_) [4][8] = reinterpret_cast(reg); // const unsigned int logical_offset = ((threadIdx.x % 8) * smem_stride) + (((threadIdx.x % 32) / 8) * 8); // unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b11100000000) >> 5); // uint32_t src_addr = cvta_to_shared_u32(src + swizzled_offset); @@ -979,15 +985,20 @@ __device__ __forceinline__ void ldmatrix_b( // : "r"(src_addr ^ 0b1000000) : "r"(src_addr + 32 * smem_stride_) ); - +#else + GGML_UNUSED(src); + GGML_UNUSED(reg); + NO_DEVICE_CODE; +#endif } template -static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, +static __global__ void conv2d_implicit_kernel_tc(const half * __restrict__ input, const half * __restrict__ kernel, half * __restrict__ output, const param_t param) { +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE constexpr unsigned int MMA_M = 16; constexpr unsigned int MMA_N = 8; @@ -997,18 +1008,18 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, const uint weightKOffset = param.c * param.r * param.s; // for convenience/readability in index calculations - const unsigned int A_stride = K; - const unsigned int B_stride = N; - const unsigned int CD_stride = N; +// const unsigned int A_stride = K; +// const unsigned int B_stride = N; +// const unsigned int CD_stride = N; // calculate how many bits of shared memory indices are going to be swizzled, and create masks - constexpr unsigned int SWIZZLE_BITS_B = int_log2(BN / 8); +// constexpr unsigned int 
SWIZZLE_BITS_B = int_log2(BN / 8); // loop bounds, constexpr where possible allows for loop unrolling constexpr unsigned int mma_tiles_per_warp_k = 4; constexpr unsigned int mma_tiles_per_warp_m = WM / MMA_M; constexpr unsigned int mma_tiles_per_warp_n = WN / MMA_N; - const unsigned int num_block_tiles_k = K / BK; + const unsigned int num_block_tiles_k = (K + (BK-1)) / BK; // calculate block/warp indices const unsigned int block_m = blockIdx.y; @@ -1057,8 +1068,8 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, // prefetch the first block tile of A,B into shared memory // half* A_block_gmem = input + (block_m * BM * A_stride); - half* A_block_gmem = input; - half* B_block_gmem = kernel + (block_n * weightKOffset); + const half* A_block_gmem = input; + const half* B_block_gmem = kernel + (block_n * weightKOffset); tileMemcpySwizzleA(A_block_gmem, A_block_smem, inChannelOffset, param); tileMemcpySwizzleB(B_block_gmem, B_block_smem, weightKOffset, param); @@ -1074,10 +1085,10 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, if (block_k != num_block_tiles_k) { // half* A_block_gmem = A + (block_m * BM * A_stride) + (block_k * BK); - half* A_block_gmem = input; - half* B_block_gmem = kernel + (block_n * weightKOffset); - tileMemcpyLoad(A_block_gmem, A_gmem_cache_reg, block_k * BK, inChannelOffset, param); - tileMemcpyLoad(B_block_gmem, B_gmem_cache_reg, block_k * BK, weightKOffset, param); + const half* A_block_gmem = input; + const half* B_block_gmem = kernel + (block_n * weightKOffset); + tileMemcpyLoadA(A_block_gmem, A_gmem_cache_reg, block_k * BK, inChannelOffset, param); + tileMemcpyLoadB(B_block_gmem, B_gmem_cache_reg, block_k * BK, weightKOffset, param); } half* A_warp_tile = A_block_smem + (warp_m * WM * BK); half* B_warp_tile = B_block_smem + (warp_n * WN * BK); @@ -1124,14 +1135,14 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, } // reuse smem - half *smemoutput = reinterpret_cast(shmem); + half *smemoutput = shmem; const uint lane_id = threadIdx.x % WARPSIZE; const uint mma_row = lane_id / 4; const uint mma_col = lane_id % 4; - const uint output_lds_addr = warp_id * WSUBM * WSUBN + lane_id; + const uint output_lds_addr = warp_m * WM * BN/2 + lane_id * BN/2 + warp_n * WN/2; const uint output_sts_addr = warp_m * WM * BN/2 + mma_row * BN/2 + warp_n * WN/2 + mma_col * 2; - const uint m_idx = by * BN + mma_tid_y * WN; - const uint n_idx = block_m * BM + warp_m * WM; + const uint m_idx = block_n * BN + warp_n * WN; + const uint n_idx = block_m * BM + warp_m * WM + lane_id; #pragma unroll for (int i = 0; i < 2; ++i) @@ -1142,72 +1153,42 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, { // output sts uint32_t (®_)[2] = reinterpret_cast(acc_register_[mma_m][mma_n]); - uint32_t* dst_ptr = reinterpret_cast(&smemoutput[output_sts_addr + - mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N]); + const uint idx = output_sts_addr + + mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N; + uint32_t* dst_ptr = reinterpret_cast(&smemoutput[idx]); dst_ptr[0] = reg_[0]; - dst_ptr = reinterpret_cast(&smemoutput[output_sts_addr + - mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N + 8 * BN / 2]); + dst_ptr = reinterpret_cast(&smemoutput[idx + 8 * BN / 2]); dst_ptr[0] = reg_[1]; } } __syncthreads(); + #pragma unroll - const uint row = m_idx + j * WSUBN + (lane_id + subk * WARPSIZE) / WSUBM; - const uint gemm_i = n_idx + 
i * WSUBM + (lane_id + subk * WARPSIZE) % WSUBM; - const int n = fastdiv(gemm_i, param.OHOW_fastdiv); - const int col = fastmodulo(gemm_i, param.OHOW_fastdiv); - if (n < param.n && row < param.k && col < param.Oh * param.Ow){ - // int outOffset = z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow + (m_idx + i * 16 + subk) * param.Oh * param.Ow + (n_idx + j * 32); - // if (n < param.n && (m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow) - // param.interm[outOffset] = smemoutput[output_lds_addr + subk * 32]; - const uint outOffset = ksplit > 0 ? - z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow + - row * param.Oh * param.Ow + col : - z * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col; - output[outOffset] = smemoutput[output_lds_addr + subk * WARPSIZE]; + for (int subk = 0; subk < WN / 2; ++subk){ + for (int j = 0; j < 4; ++j){ + const uint row = m_idx + subk + i * WN / 2; + const uint gemm_i = n_idx + j*32; + const int n = fastdiv(gemm_i, param.OHOW_fastdiv); + const int col = fastmodulo(gemm_i, param.OHOW_fastdiv); + if(n < param.n && row < param.k && col < param.Oh * param.Ow){ + // int outOffset = z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow + (m_idx + i * 16 + subk) * param.Oh * param.Ow + (n_idx + j * 32); + // if (n < param.n && (m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow) + // param.interm[outOffset] = smemoutput[output_lds_addr + subk * 32]; + const uint outOffset = n * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col; + output[outOffset] = smemoutput[output_lds_addr + subk + j*32*BN/2]; + } + } } } - - ////////////// - // epilogue // - ////////////// -// half alpha_ = (half)alpha; -// half beta_ = (half)beta; -// half C_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][4]; - -// // calculate pointers for this warps C and D tiles -// half* C_block_gmem = C + (block_m * BM_dim * CD_stride) + (block_n * BN_dim); -// half* C_warp_gmem = C_block_gmem + (warp_m * WM_dim * CD_stride) + (warp_n * WN_dim); -// half* D_block_gmem = D + (block_m * BM_dim * CD_stride) + (block_n * BN_dim); -// half* D_warp_gmem = D_block_gmem + (warp_m * WM_dim * CD_stride) + (warp_n * WN_dim); - -// for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++) -// { -// for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++) -// { -// half* C_mma_tile = C_warp_gmem + (mma_m * MMA_M_dim * CD_stride) + (mma_n * MMA_N_dim); -// ldmatrix_m16n8_gmem(C_mma_tile, C_register[mma_m][mma_n], N * sizeof(half)); - -// // scale C by beta -// acc_register_[mma_m][mma_n][0] = acc_register_[mma_m][mma_n][0] * alpha_ + C_register[mma_m][mma_n][0] * beta_; -// acc_register_[mma_m][mma_n][1] = acc_register_[mma_m][mma_n][1] * alpha_ + C_register[mma_m][mma_n][1] * beta_; -// acc_register_[mma_m][mma_n][2] = acc_register_[mma_m][mma_n][2] * alpha_ + C_register[mma_m][mma_n][2] * beta_; -// acc_register_[mma_m][mma_n][3] = acc_register_[mma_m][mma_n][3] * alpha_ + C_register[mma_m][mma_n][3] * beta_; -// } -// } - -// for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++) -// { -// for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++) -// { -// half* D_mma_tile = D_warp_gmem + (mma_m * MMA_M_dim * CD_stride) + (mma_n * MMA_N_dim); -// stmatrix_m16n8(D_mma_tile, acc_register_[mma_m][mma_n], N * sizeof(half)); -// } -// } - +#else + GGML_UNUSED(input); + GGML_UNUSED(kernel); + GGML_UNUSED(output); + 
GGML_UNUSED(param); + NO_DEVICE_CODE; +#endif } -#endif #define NUM_VARIANTS 6 @@ -1266,11 +1247,53 @@ static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, } } -static void conv2d_implicit_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const param_t P, cudaStream_t st) { +static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const float * X_D, const half * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) { +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + if (GGML_CUDA_CC_IS_NVIDIA(cc) && ampere_mma_available(cc) && P.layout == 0 && P.c % 8 == 0) { + constexpr unsigned int BM_dim = 256; + constexpr unsigned int BN_dim = 256; + constexpr unsigned int BK_dim = 32; + + constexpr unsigned int WARPS_PER_BLOCK_M = 2; + constexpr unsigned int WARPS_PER_BLOCK_N = 4; + constexpr unsigned int WARPS_PER_BLOCK_K = 4; + + constexpr unsigned int WM_dim = BM_dim / WARPS_PER_BLOCK_M; + constexpr unsigned int WN_dim = BN_dim / WARPS_PER_BLOCK_N; + constexpr unsigned int WK_dim = BK_dim / WARPS_PER_BLOCK_K; + const unsigned int BlocksM = (P.n * P.Oh * P.Ow + BM_dim - 1) / BM_dim; + const unsigned int BlocksN = (P.k + BN_dim - 1) / BN_dim; + constexpr unsigned int ThreadsM = WARPS_PER_BLOCK_M; + constexpr unsigned int ThreadsN = WARPSIZE * WARPS_PER_BLOCK_N; + constexpr unsigned int NumThreads = ThreadsM * ThreadsN; + const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half); + dim3 gridDim(BlocksN, BlocksM); + dim3 blockDim(ThreadsN, ThreadsM); + + int id = ggml_cuda_get_device(); + ggml_cuda_pool_alloc x_f16(ctx.pool(id)); + + const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(GGML_TYPE_F32); + GGML_ASSERT(to_fp16_cuda != nullptr); + size_t ne = P.c * P.h * P.w * P.n; + x_f16.alloc(ne); + to_fp16_cuda(X_D, x_f16.get(), ne, st); + const half *X_H = x_f16.get(); + ggml_cuda_pool_alloc Y_H(ctx.pool(id), P.k * P.Oh * P.Ow * P.n); + conv2d_implicit_kernel_tc + <<>>(X_H, K_D, Y_H.get(), P); + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); + to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow * P.n, st); + }else{ + conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); + } +#else conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); +#endif } -static void conv2d_implicit_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const param_t P, cudaStream_t st) { +static void conv2d_implicit_cuda_f32(ggml_backend_cuda_context & ctx, const float * X_D, const float * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) { conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); } @@ -1286,6 +1309,7 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * cudaStream_t st = ctx.stream(); + const int cc = ggml_cuda_info().devices[ctx.device].cc; const int32_t * p = (const int32_t *) dst->op_params; const int ST_X = p[0]; // stride_x @@ -1333,8 +1357,8 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * params.layout = LT; if (kernel->type == GGML_TYPE_F16) { - conv2d_implicit_cuda_f16(X_D, (half *) K_D, Y_D, params, st); + conv2d_implicit_cuda_f16(ctx, X_D, (half *) K_D, Y_D, cc, params, st); } else { - conv2d_implicit_cuda_f32(X_D, K_D, Y_D, params, st); + conv2d_implicit_cuda_f32(ctx, X_D, K_D, Y_D, cc, params, st); } } diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index 1a54a184a8..b7d2c8ff2e 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -31,9 +31,10 @@ typedef 
struct{ template __device__ __forceinline__ void tileMemcpySwizzleB( - half* src, + const half* src, half* dst, - const unsigned int src_stride + const unsigned int src_stride, + param_t param ) { // constexpr unsigned int SWIZZLE_MASK = 0b111 << SWIZZLE_BITS; @@ -109,7 +110,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB( unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); - if (thread_row < param.k && curR < param.R && curS < param.S && curC < param.c){ + if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c){ dst_float4[dst_index] = reinterpret_cast(&src[src_index])[0]; }else{ // read 4 halves dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f); @@ -123,7 +124,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB( template __device__ __forceinline__ void tileMemcpySwizzleA( - half* src, + const half* src, half* dst, // const unsigned int src_stride, const unsigned int inChannelOffset, @@ -175,7 +176,7 @@ __device__ __forceinline__ void tileMemcpySwizzleA( dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && - curR < param.R && curS < param.S && curC < param.c){ + curR < param.r && curS < param.s && curC < param.c){ // const unsigned int src_index = thread_row * src_stride_vectorized + thread_col; const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; dst_float4[dst_index] = reinterpret_cast(&src[inOffset + inOffsetTmp])[0]; @@ -191,7 +192,7 @@ unsigned int TILE_COLS, unsigned int NUM_THREADS, unsigned int ELEMENTS_PER_THREAD> __device__ __forceinline__ void tileMemcpyLoadA( - half* src, + const half* src, float4 (&dst_reg)[ELEMENTS_PER_THREAD], // const unsigned int src_stride, const unsigned int block_k, @@ -200,7 +201,7 @@ __device__ __forceinline__ void tileMemcpyLoadA( ) { // reinterpret input/output as float4 - float4* src_float4 = reinterpret_cast(src); + // const float4* src_float4 = reinterpret_cast(src); // const unsigned int src_stride_vectorized = src_stride / 8; // # of threads is multiple of # of columns in the tile @@ -239,7 +240,7 @@ __device__ __forceinline__ void tileMemcpyLoadA( int curH = posh_ori + curR * param.d_h; // input h int curW = posw_ori + curS * param.d_w; // input w if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && - curR < param.R && curS < param.S && curC < param.c){ + curR < param.r && curS < param.s && curC < param.c){ // const unsigned int src_index = thread_row * src_stride_vectorized + thread_col; const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; dst_reg[i] = reinterpret_cast(&src[inOffset + inOffsetTmp])[0]; @@ -256,7 +257,7 @@ unsigned int TILE_COLS, unsigned int NUM_THREADS, unsigned int ELEMENTS_PER_THREAD> __device__ __forceinline__ void tileMemcpyLoadB( - half* src, + const half* src, float4 (&dst_reg)[ELEMENTS_PER_THREAD], const unsigned int block_k, const unsigned int src_stride, @@ -264,7 +265,7 @@ __device__ __forceinline__ void tileMemcpyLoadB( ) { // reinterpret input/output as float4 - float4* src_float4 = reinterpret_cast(src); + // const float4* src_float4 = reinterpret_cast(src); // const unsigned int src_stride_vectorized = src_stride / 8; // # of threads is multiple of # of columns in the tile @@ 
-295,7 +296,7 @@ __device__ __forceinline__ void tileMemcpyLoadB( // dst_reg[i] = src_float4[src_index]; // thread_row += ROW_STEP; const unsigned int src_index = thread_row * src_stride + block_k + thread_col * 8; - if (thread_row < param.k && curR < param.R && curS < param.S && curC < param.c){ + if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c){ dst_reg[i] = reinterpret_cast(&src[src_index])[0]; }else{ // read 4 halves dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f); @@ -449,5 +450,15 @@ __device__ __forceinline__ uint32_t cvta_to_shared_u32(const void *pointer) { #endif +// constexpr unsigned int int_log2(unsigned int x) +// { +// unsigned int result = 0; +// while (x >>= 1) +// { +// result++; +// } +// return result; +// } + #define CUDA_CONV2D_IMPLICT_BLOCK_SIZE 256 void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst); From be25be8ed3fc7f5ab41a7157d91d388e6363c729 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Fri, 24 Oct 2025 14:24:26 -0400 Subject: [PATCH 026/122] WIP: debugging tensor core kernel --- ggml/src/ggml-cuda/conv2d-implicit.cu | 66 +++++++++++++++----------- ggml/src/ggml-cuda/conv2d-implicit.cuh | 28 +++++------ tests/test-conv2d-implicit.cpp | 66 +++++++++++++------------- 3 files changed, 86 insertions(+), 74 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 482270e2c7..f08e19e9fb 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -259,6 +259,10 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, __syncthreads(); + if(tx == 0 && bx == 0 && by == 0 && z == 0){ + printf("non tensor \n"); + } + // if(tx == 0 && bx == 0 && by == 0 && z == 0){ // for(int i=0; i < 128; ++i) // printf("%.2f,", smeminput[i]); @@ -738,7 +742,7 @@ __device__ __forceinline__ void ldmatrix_a( half (®)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] ) { -#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +// #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE static_assert(mma_tiles_per_warp_m == 8, "mma_tiles_per_warp_m must be 4"); static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4"); @@ -881,11 +885,11 @@ __device__ __forceinline__ void ldmatrix_a( : "=r"(reg_[6][3][0]), "=r"(reg_[6][3][1]), "=r"(reg_[7][3][0]), "=r"(reg_[7][3][1]) : "r"(src_addr + 96 * smem_stride_) ); -#else - GGML_UNUSED(src); - GGML_UNUSED(reg); - NO_DEVICE_CODE; -#endif +// #else +// GGML_UNUSED(src); +// GGML_UNUSED(reg); +// NO_DEVICE_CODE; +// #endif } template @@ -894,7 +898,7 @@ __device__ __forceinline__ void ldmatrix_b( half (®)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2] ) { -#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +// #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4"); static_assert(mma_tiles_per_warp_n == 8, "mma_tiles_per_warp_n must be 8"); @@ -985,23 +989,26 @@ __device__ __forceinline__ void ldmatrix_b( // : "r"(src_addr ^ 0b1000000) : "r"(src_addr + 32 * smem_stride_) ); -#else - GGML_UNUSED(src); - GGML_UNUSED(reg); - NO_DEVICE_CODE; -#endif +// #else +// GGML_UNUSED(src); +// GGML_UNUSED(reg); +// NO_DEVICE_CODE; +// #endif } template -static __global__ void conv2d_implicit_kernel_tc(const half * __restrict__ input, +static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, const half * __restrict__ kernel, half * __restrict__ output, const param_t param) { -#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +// #if __CUDA_ARCH__ >= 
GGML_CUDA_CC_AMPERE constexpr unsigned int MMA_M = 16; constexpr unsigned int MMA_N = 8; +// if(threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y ==0) +// printf("conv2d_implicit_kernel launch BM:%d, BN:%d, BK:%d, WM:%d, WN:%d, WK:%d, NUM_THREADS:%d \n", BM, BN, BK, WM, WN, WK, NUM_THREADS); + const unsigned int K = param.c * param.r * param.s; const uint PQ = param.Oh * param.Ow; const uint inChannelOffset = param.c * param.w; @@ -1180,13 +1187,13 @@ static __global__ void conv2d_implicit_kernel_tc(const half * __restrict__ input } } } -#else - GGML_UNUSED(input); - GGML_UNUSED(kernel); - GGML_UNUSED(output); - GGML_UNUSED(param); - NO_DEVICE_CODE; -#endif +// #else +// GGML_UNUSED(input); +// GGML_UNUSED(kernel); +// GGML_UNUSED(output); +// GGML_UNUSED(param); +// NO_DEVICE_CODE; +// #endif } @@ -1248,8 +1255,8 @@ static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, } static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const float * X_D, const half * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) { -#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE if (GGML_CUDA_CC_IS_NVIDIA(cc) && ampere_mma_available(cc) && P.layout == 0 && P.c % 8 == 0) { +// #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE constexpr unsigned int BM_dim = 256; constexpr unsigned int BN_dim = 256; constexpr unsigned int BK_dim = 32; @@ -1267,6 +1274,9 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa constexpr unsigned int ThreadsN = WARPSIZE * WARPS_PER_BLOCK_N; constexpr unsigned int NumThreads = ThreadsM * ThreadsN; const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half); + + cudaFuncSetAttribute(conv2d_implicit_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75 dim3 gridDim(BlocksN, BlocksM); dim3 blockDim(ThreadsN, ThreadsM); @@ -1280,17 +1290,19 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa to_fp16_cuda(X_D, x_f16.get(), ne, st); const half *X_H = x_f16.get(); ggml_cuda_pool_alloc Y_H(ctx.pool(id), P.k * P.Oh * P.Ow * P.n); - conv2d_implicit_kernel_tc <<>>(X_H, K_D, Y_H.get(), P); const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow * P.n, st); - }else{ +// #else +// printf("non tensor path called\n"); +// conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); +// #endif + } else{ conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); } -#else - conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); -#endif + } static void conv2d_implicit_cuda_f32(ggml_backend_cuda_context & ctx, const float * X_D, const float * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) { diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index b7d2c8ff2e..7d966705b8 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -26,7 +26,7 @@ typedef struct{ uint3 OHOW_fastdiv; } param_t; -#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +// #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE // same as above, but writes are swizzled to avoid bank conflicts when shared memory is read later in the kernel template @@ -98,9 +98,9 @@ __device__ __forceinline__ void tileMemcpySwizzleB( unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; // TODO: next block_k loop - const uint curR = fastdiv(thread_col*8, param.SC_fastdiv); // 
channel offset - const uint curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - const uint curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // + const unsigned int curR = fastdiv(thread_col*8, param.SC_fastdiv); // channel offset + const unsigned int curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const unsigned int curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++) @@ -166,9 +166,9 @@ __device__ __forceinline__ void tileMemcpySwizzleA( int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q; unsigned int inOffset = n * param.c * param.h * param.w; // TODO: next block_k loop - const uint curR = fastdiv(thread_col*8, param.SC_fastdiv); // channel offset - const uint curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - const uint curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const unsigned int curR = fastdiv(thread_col*8, param.SC_fastdiv); // channel offset + const unsigned int curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const unsigned int curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset int curH = posh_ori + curR * param.d_h; // input h int curW = posw_ori + curS * param.d_w; // input w // apply swizzle to the dst index @@ -234,9 +234,9 @@ __device__ __forceinline__ void tileMemcpyLoadA( int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q; unsigned int inOffset = n * param.c * param.h * param.w; // TODO: next block_k loop - const uint curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // channel offset - const uint curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - const uint curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const unsigned int curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // channel offset + const unsigned int curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset int curH = posh_ori + curR * param.d_h; // input h int curW = posw_ori + curS * param.d_w; // input w if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && @@ -285,9 +285,9 @@ __device__ __forceinline__ void tileMemcpyLoadB( // compile time check that we provided the right amount of registers for storage static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); - const uint curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // channel offset - const uint curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - const uint curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // + const unsigned int curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // channel offset + const unsigned int curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++) @@ -448,7 
+448,7 @@ __device__ __forceinline__ uint32_t cvta_to_shared_u32(const void *pointer) { return address; } -#endif +// #endif // constexpr unsigned int int_log2(unsigned int x) // { diff --git a/tests/test-conv2d-implicit.cpp b/tests/test-conv2d-implicit.cpp index 4d416e748c..3685a10d72 100644 --- a/tests/test-conv2d-implicit.cpp +++ b/tests/test-conv2d-implicit.cpp @@ -63,8 +63,8 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, bool use_gpu size_t buffer_size = 0; { - buffer_size += KW * KH * IC * OC * ggml_type_size(GGML_TYPE_F32); // tensor a - // buffer_size += KW * KH * IC * OC * ggml_type_size(GGML_TYPE_F16); // tensor a + // buffer_size += KW * KH * IC * OC * ggml_type_size(GGML_TYPE_F32); // tensor a + buffer_size += KW * KH * IC * OC * ggml_type_size(GGML_TYPE_F16); // tensor a buffer_size += IW * IH * IC * N * ggml_type_size(GGML_TYPE_F32); // tensor b buffer_size += 1024; // overhead } @@ -112,8 +112,8 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, bool use_gpu model.ctx = ggml_init(params); // create tensors - // model.a = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F16, KW, KH, IC, OC); - model.a = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F32, KW, KH, IC, OC); + model.a = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F16, KW, KH, IC, OC); + // model.a = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F32, KW, KH, IC, OC); model.b = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F32, IW, IH, IC, N); // create a allocator @@ -124,11 +124,11 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, bool use_gpu // load data to buffer if(ggml_backend_is_cpu(model.backend)) { - // memcpy(model.a->data, hadata.data(), ggml_nbytes(model.a)); - memcpy(model.a->data, adata.data(), ggml_nbytes(model.a)); + memcpy(model.a->data, hadata.data(), ggml_nbytes(model.a)); + // memcpy(model.a->data, adata.data(), ggml_nbytes(model.a)); } else { - // ggml_backend_tensor_set(model.a, hadata.data(), 0, ggml_nbytes(model.a)); - ggml_backend_tensor_set(model.a, adata.data(), 0, ggml_nbytes(model.a)); + ggml_backend_tensor_set(model.a, hadata.data(), 0, ggml_nbytes(model.a)); + // ggml_backend_tensor_set(model.a, adata.data(), 0, ggml_nbytes(model.a)); } // alloc memory @@ -262,7 +262,7 @@ struct ggml_cgraph * build_graph_2(const test_model& model) { // printf("conv2d: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); - struct ggml_tensor* wino_res = ggml_conv_2d_implicitgemm(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1, 1); + struct ggml_tensor* wino_res = ggml_conv_2d_implicitgemm(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1, 0); // struct ggml_tensor* wino_res = ggml_conv_2d_direct(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); ggml_set_name(wino_res, "wino_res"); ggml_build_forward_expand(gf, wino_res); @@ -339,20 +339,20 @@ int main(void) { ggml_time_init(); std::vector> configs = { - std::make_tuple(64,64,48,64), - std::make_tuple(320,320,104,152), - std::make_tuple(640,640,52,76), - std::make_tuple(640,640,104,152), - std::make_tuple(960,320,104,152), - std::make_tuple(1280,1280,26,38), - std::make_tuple(1280,640,52,76), - std::make_tuple(1920,1280,26,38), - std::make_tuple(2560,1280,26,38), - std::make_tuple(512,512,104,152), - std::make_tuple(512,512,208,304), - std::make_tuple(512,256,416,608), - std::make_tuple(256,128,832,1216), - std::make_tuple(256,256,832,1216), + // std::make_tuple(64,64,48,64), + // std::make_tuple(320,320,104,152), + // std::make_tuple(640,640,52,76), + // std::make_tuple(640,640,104,152), + // 
std::make_tuple(960,320,104,152), + std::make_tuple(160,1280,26,38), + // std::make_tuple(1280,640,52,76), + // std::make_tuple(1920,1280,26,38), + // std::make_tuple(2560,1280,26,38), + // std::make_tuple(512,512,104,152), + // std::make_tuple(512,512,208,304), + // std::make_tuple(512,256,416,608), + // std::make_tuple(256,128,832,1216), + // std::make_tuple(256,256,832,1216), // std::make_tuple(320,256,1024,1920) }; @@ -375,7 +375,7 @@ int main(void) struct ggml_cgraph * gf_res_0 = NULL; - int iterations = 20; + int iterations = 0; double run_time0; std::vector conv2d_data = compute_graph(model, allocr, build_graph_0, iterations, &run_time0); @@ -436,15 +436,15 @@ int main(void) // for(int i = 0; i < ggml_nelements(wino_res); i++) { - // for(int i = 0; i < 3*28; i++) { - // float diff = fabs(conv2d_data[i] - wino_data[i]); - // // if(diff > 1.e-4) { - // printf("(%f, %f, %f, %d) \n", - // conv2d_data[i], - // wino_data[i], diff, i); - // // break; - // // } - // } + for(int i = 0; i < 26*38; i++) { + float diff = fabs(conv2d_data[i] - wino_data[i]); + // if(diff > 1.e-4) { + printf("(%f, %f, %f, %d) \n", + conv2d_data[i], + wino_data[i], diff, i); + // break; + // } + } ggml_free(model.ctx); ggml_backend_buffer_free(model.buffer); From 6c90c20cb1ff375c8b3afb1b0544e088ebb9725c Mon Sep 17 00:00:00 2001 From: bssrdf Date: Fri, 24 Oct 2025 15:33:57 -0400 Subject: [PATCH 027/122] WIP: bug fix --- ggml/src/ggml-cuda/conv2d-implicit.cu | 15 +++++++++++++- ggml/src/ggml-cuda/conv2d-implicit.cuh | 6 ++++-- tests/test-conv2d-implicit.cpp | 27 +++++++++++++------------- 3 files changed, 32 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index f08e19e9fb..de2cf4aecb 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -1081,7 +1081,11 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, tileMemcpySwizzleB(B_block_gmem, B_block_smem, weightKOffset, param); // construct const pointers to warp tiles for use inside the inner loop - +// if(threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x ==0 && blockIdx.y ==0){ +// for(int i = 0; i < 32; ++i) +// printf("%.2f,", __half2float(A_block_smem[i])); +// printf("\n"); +// } int offset_direction = 1; @@ -1127,6 +1131,14 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, } } } + if(threadIdx.x == 0 && threadIdx.y ==0 && blockIdx.x ==0 && blockIdx.y ==0){ + printf(" %d: %f, %f, %f, %f \n", block_k, __half2float(acc_register_[0][0][0]), __half2float(acc_register_[0][0][1]), + __half2float(acc_register_[0][0][2]), __half2float(acc_register_[0][0][3])); + printf(" %d: %f, %f, %f, %f \n", block_k, __half2float(A_register_[0][0][0]), __half2float(A_register_[0][0][1]), + __half2float(A_register_[0][0][2]), __half2float(A_register_[0][0][3])); + printf(" %d: %f, %f, %f, %f \n", block_k, __half2float(B_register_[0][0][0]), __half2float(B_register_[0][0][1]), + __half2float(B_register_[0][0][2]), __half2float(B_register_[0][0][3])); + } if (block_k != num_block_tiles_k) @@ -1141,6 +1153,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, } } + // reuse smem half *smemoutput = shmem; const uint lane_id = threadIdx.x % WARPSIZE; diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index 7d966705b8..3ea0461218 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -159,7 +159,8 @@ 
__device__ __forceinline__ void tileMemcpySwizzleA( #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++) { - unsigned int gemm_i = blockDim.y * TILE_ROWS + thread_row; + // unsigned int gemm_i = blockDim.y * TILE_ROWS + thread_row; + unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row; unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv); unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv); int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p; @@ -227,7 +228,8 @@ __device__ __forceinline__ void tileMemcpyLoadA( // const unsigned int src_index = thread_row * src_stride_vectorized + thread_col; // dst_reg[i] = src_float4[src_index]; // thread_row += ROW_STEP; - unsigned int gemm_i = blockDim.y * TILE_ROWS + thread_row; + // unsigned int gemm_i = blockDim.y * TILE_ROWS + thread_row; + unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row; unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv); unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv); int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p; diff --git a/tests/test-conv2d-implicit.cpp b/tests/test-conv2d-implicit.cpp index 3685a10d72..58cd74e7a4 100644 --- a/tests/test-conv2d-implicit.cpp +++ b/tests/test-conv2d-implicit.cpp @@ -48,7 +48,7 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, bool use_gpu // Initialize adata std::vector adata(KW * KH * IC * OC); for (int i = 0; i < KW * KH * IC * OC; i++) { - adata[i] = 2.5f; + adata[i] = 2.f; } // Convert adata to fp16 format @@ -344,7 +344,7 @@ int main(void) // std::make_tuple(640,640,52,76), // std::make_tuple(640,640,104,152), // std::make_tuple(960,320,104,152), - std::make_tuple(160,1280,26,38), + std::make_tuple(128,1280,26,38), // std::make_tuple(1280,640,52,76), // std::make_tuple(1920,1280,26,38), // std::make_tuple(2560,1280,26,38), @@ -398,7 +398,8 @@ int main(void) struct ggml_cgraph * gf_res_1 = NULL; double run_time1; - std::vector wino_data = compute_graph(model, allocr, build_graph_1, iterations, &run_time1); + // std::vector wino_data = compute_graph(model, allocr, build_graph_1, iterations, &run_time1); + conv2d_data = compute_graph(model, allocr, build_graph_1, iterations, &run_time1); ggml_gallocr_free(allocr); @@ -419,7 +420,7 @@ int main(void) struct ggml_cgraph * gf_res_2 = NULL; double run_time2; - wino_data = compute_graph(model, allocr, build_graph_2, iterations, &run_time2); + std::vector wino_data = compute_graph(model, allocr, build_graph_2, iterations, &run_time2); if(k==0) { @@ -436,15 +437,15 @@ int main(void) // for(int i = 0; i < ggml_nelements(wino_res); i++) { - for(int i = 0; i < 26*38; i++) { - float diff = fabs(conv2d_data[i] - wino_data[i]); - // if(diff > 1.e-4) { - printf("(%f, %f, %f, %d) \n", - conv2d_data[i], - wino_data[i], diff, i); - // break; - // } - } + // for(int i = 0; i < 26*38; i++) { + // float diff = fabs(conv2d_data[i] - wino_data[i]); + // // if(diff > 1.e-4) { + // printf("(%f, %f, %f, %d) \n", + // conv2d_data[i], + // wino_data[i], diff, i); + // // break; + // // } + // } ggml_free(model.ctx); ggml_backend_buffer_free(model.buffer); From 24b553204b94bbc1aed4d8e245e69c25ddb40c88 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Fri, 24 Oct 2025 16:53:40 -0400 Subject: [PATCH 028/122] WIP: fixed another bug --- ggml/src/ggml-cuda/conv2d-implicit.cu | 36 +++++++++++++++++++++------ tests/test-conv2d-implicit.cpp | 18 +++++++------- 2 files changed, 37 insertions(+), 17 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu 
b/ggml/src/ggml-cuda/conv2d-implicit.cu index de2cf4aecb..f6059fc3ae 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -1131,14 +1131,14 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, } } } - if(threadIdx.x == 0 && threadIdx.y ==0 && blockIdx.x ==0 && blockIdx.y ==0){ - printf(" %d: %f, %f, %f, %f \n", block_k, __half2float(acc_register_[0][0][0]), __half2float(acc_register_[0][0][1]), - __half2float(acc_register_[0][0][2]), __half2float(acc_register_[0][0][3])); - printf(" %d: %f, %f, %f, %f \n", block_k, __half2float(A_register_[0][0][0]), __half2float(A_register_[0][0][1]), - __half2float(A_register_[0][0][2]), __half2float(A_register_[0][0][3])); - printf(" %d: %f, %f, %f, %f \n", block_k, __half2float(B_register_[0][0][0]), __half2float(B_register_[0][0][1]), - __half2float(B_register_[0][0][2]), __half2float(B_register_[0][0][3])); - } + // if(threadIdx.x == 4 && threadIdx.y ==0 && blockIdx.x ==0 && blockIdx.y ==0){ + // printf(" %d: %f, %f, %f, %f \n", block_k, __half2float(acc_register_[0][0][0]), __half2float(acc_register_[0][0][1]), + // __half2float(acc_register_[0][0][2]), __half2float(acc_register_[0][0][3])); + // printf(" %d: %f, %f, %f, %f \n", block_k, __half2float(A_register_[0][0][0]), __half2float(A_register_[0][0][1]), + // __half2float(A_register_[0][0][2]), __half2float(A_register_[0][0][3])); + // printf(" %d: %f, %f, %f, %f \n", block_k, __half2float(B_register_[0][0][0]), __half2float(B_register_[0][0][1]), + // __half2float(B_register_[0][0][2]), __half2float(B_register_[0][0][3])); + // } if (block_k != num_block_tiles_k) @@ -1167,6 +1167,8 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, #pragma unroll for (int i = 0; i < 2; ++i) { + __syncthreads(); + for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++) { for (unsigned int mma_n = i * mma_tiles_per_warp_n/2; mma_n < (i+1)*mma_tiles_per_warp_n/2; mma_n++) @@ -1182,6 +1184,20 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, } } __syncthreads(); + // if(threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x ==0 && blockIdx.y ==0){ + // for(int ii = 0; ii < 128; ++ii) + // printf("%.2f,", __half2float(smemoutput[ii])); + // printf("\n"); + // for(int ii = 128; ii < 256; ++ii) + // printf("%.2f,", __half2float(smemoutput[ii])); + // printf("\n"); + // for(int ii = 0; ii < 128; ++ii) + // printf("%.2f,", __half2float(smemoutput[ii*128])); + // printf("\n"); + // for(int ii = 128; ii < 256; ++ii) + // printf("%.2f,", __half2float(smemoutput[ii*128])); + // printf("\n"); + // } #pragma unroll for (int subk = 0; subk < WN / 2; ++subk){ @@ -1196,6 +1212,10 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, // param.interm[outOffset] = smemoutput[output_lds_addr + subk * 32]; const uint outOffset = n * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col; output[outOffset] = smemoutput[output_lds_addr + subk + j*32*BN/2]; + if(outOffset == 32){ + printf("(%u, %u, %u, %u), output[%d,%d,%d]=%f \n", threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y, + n, row, col, __half2float(output[outOffset])); + } } } } diff --git a/tests/test-conv2d-implicit.cpp b/tests/test-conv2d-implicit.cpp index 58cd74e7a4..19d2826240 100644 --- a/tests/test-conv2d-implicit.cpp +++ b/tests/test-conv2d-implicit.cpp @@ -437,15 +437,15 @@ int main(void) // for(int i = 0; i < ggml_nelements(wino_res); i++) { - // for(int i = 0; i < 26*38; i++) { - // 
float diff = fabs(conv2d_data[i] - wino_data[i]); - // // if(diff > 1.e-4) { - // printf("(%f, %f, %f, %d) \n", - // conv2d_data[i], - // wino_data[i], diff, i); - // // break; - // // } - // } + for(int i = 0; i < 26*38; i++) { + float diff = fabs(conv2d_data[i] - wino_data[i]); + // if(diff > 1.e-4) { + printf("(%f, %f, %f, %d) \n", + conv2d_data[i], + wino_data[i], diff, i); + // break; + // } + } ggml_free(model.ctx); ggml_backend_buffer_free(model.buffer); From 980ddc1e87ac60439d0544a6cf5b5060d1fff6af Mon Sep 17 00:00:00 2001 From: bssrdf Date: Fri, 24 Oct 2025 21:56:58 -0400 Subject: [PATCH 029/122] properly use __CUDA_ARCH__ to protect the tensor path --- ggml/src/ggml-cuda/conv2d-implicit.cu | 57 +++++------ ggml/src/ggml-cuda/conv2d-implicit.cuh | 125 ++++++++----------------- tests/test-conv2d-implicit.cpp | 48 +++++----- 3 files changed, 95 insertions(+), 135 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index f6059fc3ae..000fd89e20 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -742,7 +742,7 @@ __device__ __forceinline__ void ldmatrix_a( half (®)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] ) { -// #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING static_assert(mma_tiles_per_warp_m == 8, "mma_tiles_per_warp_m must be 4"); static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4"); @@ -885,11 +885,11 @@ __device__ __forceinline__ void ldmatrix_a( : "=r"(reg_[6][3][0]), "=r"(reg_[6][3][1]), "=r"(reg_[7][3][0]), "=r"(reg_[7][3][1]) : "r"(src_addr + 96 * smem_stride_) ); -// #else -// GGML_UNUSED(src); -// GGML_UNUSED(reg); -// NO_DEVICE_CODE; -// #endif +#else + GGML_UNUSED(src); + GGML_UNUSED(reg); + NO_DEVICE_CODE; +#endif } template @@ -898,7 +898,7 @@ __device__ __forceinline__ void ldmatrix_b( half (®)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2] ) { -// #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4"); static_assert(mma_tiles_per_warp_n == 8, "mma_tiles_per_warp_n must be 8"); @@ -989,11 +989,11 @@ __device__ __forceinline__ void ldmatrix_b( // : "r"(src_addr ^ 0b1000000) : "r"(src_addr + 32 * smem_stride_) ); -// #else -// GGML_UNUSED(src); -// GGML_UNUSED(reg); -// NO_DEVICE_CODE; -// #endif +#else + GGML_UNUSED(src); + GGML_UNUSED(reg); + NO_DEVICE_CODE; +#endif } template= GGML_CUDA_CC_AMPERE +#if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING constexpr unsigned int MMA_M = 16; constexpr unsigned int MMA_N = 8; @@ -1010,7 +1010,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, // printf("conv2d_implicit_kernel launch BM:%d, BN:%d, BK:%d, WM:%d, WN:%d, WK:%d, NUM_THREADS:%d \n", BM, BN, BK, WM, WN, WK, NUM_THREADS); const unsigned int K = param.c * param.r * param.s; - const uint PQ = param.Oh * param.Ow; +// const uint PQ = param.Oh * param.Ow; const uint inChannelOffset = param.c * param.w; const uint weightKOffset = param.c * param.r * param.s; @@ -1153,7 +1153,8 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, } } - + + // reuse smem half *smemoutput = shmem; const uint lane_id = threadIdx.x % WARPSIZE; @@ -1212,21 +1213,22 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, // param.interm[outOffset] = smemoutput[output_lds_addr + subk * 32]; const uint outOffset = n * param.k * param.Oh * param.Ow + row * param.Oh 
* param.Ow + col; output[outOffset] = smemoutput[output_lds_addr + subk + j*32*BN/2]; - if(outOffset == 32){ - printf("(%u, %u, %u, %u), output[%d,%d,%d]=%f \n", threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y, - n, row, col, __half2float(output[outOffset])); - } + // if(outOffset == 32){ + // printf("(%u, %u, %u, %u), output[%d,%d,%d]=%f \n", threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y, + // n, row, col, __half2float(output[outOffset])); + // } } } } } -// #else -// GGML_UNUSED(input); -// GGML_UNUSED(kernel); -// GGML_UNUSED(output); -// GGML_UNUSED(param); -// NO_DEVICE_CODE; -// #endif + +#else + GGML_UNUSED(input); + GGML_UNUSED(kernel); + GGML_UNUSED(output); + GGML_UNUSED(param); + NO_DEVICE_CODE; +#endif } @@ -1289,7 +1291,8 @@ static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const float * X_D, const half * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) { if (GGML_CUDA_CC_IS_NVIDIA(cc) && ampere_mma_available(cc) && P.layout == 0 && P.c % 8 == 0) { -// #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +// #if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA + // printf("tensor core path called\n"); constexpr unsigned int BM_dim = 256; constexpr unsigned int BN_dim = 256; constexpr unsigned int BK_dim = 32; diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index 3ea0461218..69942bffac 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -26,7 +26,7 @@ typedef struct{ uint3 OHOW_fastdiv; } param_t; -// #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + // same as above, but writes are swizzled to avoid bank conflicts when shared memory is read later in the kernel template @@ -37,6 +37,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB( param_t param ) { +#if __CUDA_ARCH__ >= GGML_CUDA_TURING // constexpr unsigned int SWIZZLE_MASK = 0b111 << SWIZZLE_BITS; // // reinterpret input/output as float4 @@ -117,6 +118,13 @@ __device__ __forceinline__ void tileMemcpySwizzleB( } thread_row += ROW_STEP; } +#else + GGML_UNUSED(src); + GGML_UNUSED(dst); + GGML_UNUSED(src_stride); + GGML_UNUSED(param); + NO_DEVICE_CODE; +#endif } @@ -131,6 +139,7 @@ __device__ __forceinline__ void tileMemcpySwizzleA( param_t param ) { +#if __CUDA_ARCH__ >= GGML_CUDA_TURING constexpr unsigned int SWIZZLE_MASK_1 = 0b10000; constexpr unsigned int SWIZZLE_BITS_1 = 4; constexpr unsigned int SWIZZLE_MASK_2 = 0b1100; @@ -186,6 +195,13 @@ __device__ __forceinline__ void tileMemcpySwizzleA( } thread_row += ROW_STEP; } +#else + GGML_UNUSED(src); + GGML_UNUSED(dst); + GGML_UNUSED(inChannelOffset); + GGML_UNUSED(param); + NO_DEVICE_CODE; +#endif } template= GGML_CUDA_TURING // reinterpret input/output as float4 // const float4* src_float4 = reinterpret_cast(src); // const unsigned int src_stride_vectorized = src_stride / 8; @@ -251,6 +268,14 @@ __device__ __forceinline__ void tileMemcpyLoadA( } thread_row += ROW_STEP; } +#else + GGML_UNUSED(src); + GGML_UNUSED(dst_reg); + GGML_UNUSED(block_k); + GGML_UNUSED(inChannelOffset); + GGML_UNUSED(param); + NO_DEVICE_CODE; +#endif } @@ -266,6 +291,7 @@ __device__ __forceinline__ void tileMemcpyLoadB( param_t param ) { +#if __CUDA_ARCH__ >= GGML_CUDA_TURING // reinterpret input/output as float4 // const float4* src_float4 = reinterpret_cast(src); // const unsigned int src_stride_vectorized = src_stride / 8; @@ -305,91 +331,18 @@ __device__ __forceinline__ void tileMemcpyLoadB( } thread_row += ROW_STEP; } 
+#else + GGML_UNUSED(src); + GGML_UNUSED(dst_reg); + GGML_UNUSED(block_k); + GGML_UNUSED(src_stride); + GGML_UNUSED(param); + NO_DEVICE_CODE; +#endif } -// template -// __device__ __forceinline__ void tileMemcpySwizzleStoreB( -// float4 src_reg[ELEMENTS_PER_THREAD], -// half* dst -// ) -// { -// constexpr unsigned int SWIZZLE_MASK = 0b111 << SWIZZLE_BITS; - -// // reinterpret input/output as float4 -// float4* dst_float4 = reinterpret_cast(dst); - -// // # of threads is multiple of # of columns in the tile -// constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; -// static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); - -// // flatten out 2d grid of threads into in order of increasing threadIdx.x -// const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; - -// // assign each thread a row/column in the tile, calculate how many iterations we need -// // to cover the whole tile -// constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; -// constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; -// unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; -// const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; - -// // compile time check that we provided the right amount of registers for storage -// static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); - -// #pragma unroll -// for (unsigned int i = 0; i < NUM_ITERS; i++) -// { -// // apply swizzle to the dst index -// unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; -// dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK) >> SWIZZLE_BITS); -// dst_float4[dst_index] = src_reg[i]; -// thread_row += ROW_STEP; -// } -// } // same as above but without the swizzle -template -__device__ __forceinline__ void tileMemcpyStore( - float4 src_reg[ELEMENTS_PER_THREAD], - half* dst, - unsigned int dst_stride_float4 -) -{ - // reinterpret input/output as float4 - float4* dst_float4 = reinterpret_cast(dst); - - // # of threads is multiple of # of columns in the tile - constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; - static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); - - // flatten out 2d grid of threads into in order of increasing threadIdx.x - const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; - - // assign each thread a row/column in the tile, calculate how many iterations we need - // to cover the whole tile - constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; - constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; - unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; - const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; - - // compile time check that we provided the right amount of registers for storage - static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); - - #pragma unroll - for (unsigned int i = 0; i < NUM_ITERS; i++) - { - // apply swizzle to the dst index - unsigned int dst_index = thread_row * dst_stride_float4 + thread_col; - dst_float4[dst_index] = src_reg[i]; - thread_row += ROW_STEP; - } -} // this is a special case of the above for when TILE_COLS == 32 template= GGML_CUDA_TURING constexpr unsigned int SWIZZLE_MASK_1 = 0b10000; constexpr unsigned int SWIZZLE_BITS_1 = 4; constexpr unsigned int SWIZZLE_MASK_2 = 0b1100; @@ -436,6 +390,11 @@ __device__ __forceinline__ void tileMemcpySwizzleStore( dst_float4[dst_index] = src_reg[i]; thread_row += ROW_STEP; } +#else + GGML_UNUSED(src_reg); + GGML_UNUSED(dst); + NO_DEVICE_CODE; +#endif } __device__ __forceinline__ uint32_t 
cvta_to_shared_u32(const void *pointer) { @@ -450,8 +409,6 @@ __device__ __forceinline__ uint32_t cvta_to_shared_u32(const void *pointer) { return address; } -// #endif - // constexpr unsigned int int_log2(unsigned int x) // { // unsigned int result = 0; diff --git a/tests/test-conv2d-implicit.cpp b/tests/test-conv2d-implicit.cpp index 19d2826240..836bb10637 100644 --- a/tests/test-conv2d-implicit.cpp +++ b/tests/test-conv2d-implicit.cpp @@ -339,20 +339,20 @@ int main(void) { ggml_time_init(); std::vector> configs = { - // std::make_tuple(64,64,48,64), - // std::make_tuple(320,320,104,152), - // std::make_tuple(640,640,52,76), - // std::make_tuple(640,640,104,152), - // std::make_tuple(960,320,104,152), - std::make_tuple(128,1280,26,38), - // std::make_tuple(1280,640,52,76), - // std::make_tuple(1920,1280,26,38), - // std::make_tuple(2560,1280,26,38), - // std::make_tuple(512,512,104,152), - // std::make_tuple(512,512,208,304), - // std::make_tuple(512,256,416,608), - // std::make_tuple(256,128,832,1216), - // std::make_tuple(256,256,832,1216), + std::make_tuple(64,64,48,64), + std::make_tuple(320,320,104,152), + std::make_tuple(640,640,52,76), + std::make_tuple(640,640,104,152), + std::make_tuple(960,320,104,152), + std::make_tuple(1280,1280,26,38), + std::make_tuple(1280,640,52,76), + std::make_tuple(1920,1280,26,38), + std::make_tuple(2560,1280,26,38), + std::make_tuple(512,512,104,152), + std::make_tuple(512,512,208,304), + std::make_tuple(512,256,416,608), + std::make_tuple(256,128,832,1216), + std::make_tuple(256,256,832,1216), // std::make_tuple(320,256,1024,1920) }; @@ -375,7 +375,7 @@ int main(void) struct ggml_cgraph * gf_res_0 = NULL; - int iterations = 0; + int iterations = 20; double run_time0; std::vector conv2d_data = compute_graph(model, allocr, build_graph_0, iterations, &run_time0); @@ -437,15 +437,15 @@ int main(void) // for(int i = 0; i < ggml_nelements(wino_res); i++) { - for(int i = 0; i < 26*38; i++) { - float diff = fabs(conv2d_data[i] - wino_data[i]); - // if(diff > 1.e-4) { - printf("(%f, %f, %f, %d) \n", - conv2d_data[i], - wino_data[i], diff, i); - // break; - // } - } + // for(int i = 0; i < 26*38; i++) { + // float diff = fabs(conv2d_data[i] - wino_data[i]); + // // if(diff > 1.e-4) { + // printf("(%f, %f, %f, %d) \n", + // conv2d_data[i], + // wino_data[i], diff, i); + // // break; + // // } + // } ggml_free(model.ctx); ggml_backend_buffer_free(model.buffer); From c45df12ee7b0755ab9296230d8a3e7ae25bfcab3 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Fri, 24 Oct 2025 22:40:34 -0400 Subject: [PATCH 030/122] this case is broken; to be debugged --- tests/test-conv2d-implicit.cpp | 49 +++++++++++++++++----------------- 1 file changed, 25 insertions(+), 24 deletions(-) diff --git a/tests/test-conv2d-implicit.cpp b/tests/test-conv2d-implicit.cpp index 836bb10637..f6dfa8c1b4 100644 --- a/tests/test-conv2d-implicit.cpp +++ b/tests/test-conv2d-implicit.cpp @@ -339,20 +339,20 @@ int main(void) { ggml_time_init(); std::vector> configs = { - std::make_tuple(64,64,48,64), - std::make_tuple(320,320,104,152), - std::make_tuple(640,640,52,76), - std::make_tuple(640,640,104,152), - std::make_tuple(960,320,104,152), - std::make_tuple(1280,1280,26,38), - std::make_tuple(1280,640,52,76), - std::make_tuple(1920,1280,26,38), - std::make_tuple(2560,1280,26,38), - std::make_tuple(512,512,104,152), - std::make_tuple(512,512,208,304), - std::make_tuple(512,256,416,608), - std::make_tuple(256,128,832,1216), - std::make_tuple(256,256,832,1216), + // std::make_tuple(64,64,48,64), + // 
std::make_tuple(320,320,104,152), + // std::make_tuple(640,640,52,76), + // std::make_tuple(640,640,104,152), + // std::make_tuple(960,320,104,152), + std::make_tuple(640,128,26,38), + // std::make_tuple(1280,640,52,76), + // std::make_tuple(1920,1280,26,38), + // std::make_tuple(2560,1280,26,38), + // std::make_tuple(512,512,104,152), + // std::make_tuple(512,512,208,304), + // std::make_tuple(512,256,416,608), + // std::make_tuple(256,128,832,1216), + // std::make_tuple(256,256,832,1216), // std::make_tuple(320,256,1024,1920) }; @@ -375,7 +375,7 @@ int main(void) struct ggml_cgraph * gf_res_0 = NULL; - int iterations = 20; + int iterations = 0; double run_time0; std::vector conv2d_data = compute_graph(model, allocr, build_graph_0, iterations, &run_time0); @@ -437,15 +437,16 @@ int main(void) // for(int i = 0; i < ggml_nelements(wino_res); i++) { - // for(int i = 0; i < 26*38; i++) { - // float diff = fabs(conv2d_data[i] - wino_data[i]); - // // if(diff > 1.e-4) { - // printf("(%f, %f, %f, %d) \n", - // conv2d_data[i], - // wino_data[i], diff, i); - // // break; - // // } - // } + for(int i = 0; i < 26*38; i++) { + // for(int i = 0; i < conv2d_data.size(); i++) { + float diff = fabs(conv2d_data[i] - wino_data[i]); + // if(diff > 1.e-4) { + printf("(%f, %f, %f, %d) \n", + conv2d_data[i], + wino_data[i], diff, i); + // break; + // } + } ggml_free(model.ctx); ggml_backend_buffer_free(model.buffer); From 610e41ae2d2754103475ab1e75a5573bb8e8a3ac Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sat, 25 Oct 2025 11:10:39 -0400 Subject: [PATCH 031/122] still debugging --- ggml/src/ggml-cuda/conv2d-implicit.cu | 26 +++++++++++++++++++------- tests/test-conv2d-implicit.cpp | 4 ++-- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 000fd89e20..fa7a905d39 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -1130,14 +1130,26 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, ); } } + // if(threadIdx.x == 0 && threadIdx.y ==0 && blockIdx.x ==0 && blockIdx.y ==0){ + // printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(acc_register_[3][0][0]), __half2float(acc_register_[3][0][1]), + // __half2float(acc_register_[3][0][2]), __half2float(acc_register_[3][0][3])); + // printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(A_register_[3][mma_k][0]), __half2float(A_register_[3][mma_k][1]), + // __half2float(A_register_[3][mma_k][2]), __half2float(A_register_[3][mma_k][3])); + // printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(B_register_[mma_k][0][0]), __half2float(B_register_[mma_k][0][1]), + // __half2float(B_register_[mma_k][0][2]), __half2float(B_register_[mma_k][0][3])); + // } + // if(threadIdx.x < 4 && threadIdx.y ==0 && blockIdx.x ==0 && blockIdx.y ==0){ + // printf("A %d, %d, %d: %f, %f \n", block_k, mma_k, threadIdx.x, __half2float(A_register_[3][mma_k][0]), __half2float(A_register_[3][mma_k][1])); + // printf("B %d, %d, %d: %f, %f \n", block_k, mma_k, threadIdx.x, __half2float(B_register_[mma_k][0][0]), __half2float(B_register_[mma_k][0][1])); + // } } - // if(threadIdx.x == 4 && threadIdx.y ==0 && blockIdx.x ==0 && blockIdx.y ==0){ - // printf(" %d: %f, %f, %f, %f \n", block_k, __half2float(acc_register_[0][0][0]), __half2float(acc_register_[0][0][1]), - // __half2float(acc_register_[0][0][2]), __half2float(acc_register_[0][0][3])); - // printf(" %d: %f, %f, %f, %f \n", block_k, 
__half2float(A_register_[0][0][0]), __half2float(A_register_[0][0][1]), - // __half2float(A_register_[0][0][2]), __half2float(A_register_[0][0][3])); - // printf(" %d: %f, %f, %f, %f \n", block_k, __half2float(B_register_[0][0][0]), __half2float(B_register_[0][0][1]), - // __half2float(B_register_[0][0][2]), __half2float(B_register_[0][0][3])); + // if(threadIdx.x == 0 && threadIdx.y ==0 && blockIdx.x ==0 && blockIdx.y ==0){ + // printf(" %d: %f, %f, %f, %f \n", block_k, __half2float(acc_register_[3][0][0]), __half2float(acc_register_[3][0][1]), + // __half2float(acc_register_[3][0][2]), __half2float(acc_register_[3][0][3])); + // printf(" %d: %f, %f, %f, %f \n", block_k, __half2float(A_register_[3][0][0]), __half2float(A_register_[3][0][1]), + // __half2float(A_register_[3][0][2]), __half2float(A_register_[3][0][3])); + // printf(" %d: %f, %f, %f, %f \n", block_k, __half2float(B_register_[3][0][0]), __half2float(B_register_[3][0][1]), + // __half2float(B_register_[3][0][2]), __half2float(B_register_[3][0][3])); // } diff --git a/tests/test-conv2d-implicit.cpp b/tests/test-conv2d-implicit.cpp index f6dfa8c1b4..4b9222a19e 100644 --- a/tests/test-conv2d-implicit.cpp +++ b/tests/test-conv2d-implicit.cpp @@ -48,7 +48,7 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, bool use_gpu // Initialize adata std::vector adata(KW * KH * IC * OC); for (int i = 0; i < KW * KH * IC * OC; i++) { - adata[i] = 2.f; + adata[i] = 0.2f; } // Convert adata to fp16 format @@ -344,7 +344,7 @@ int main(void) // std::make_tuple(640,640,52,76), // std::make_tuple(640,640,104,152), // std::make_tuple(960,320,104,152), - std::make_tuple(640,128,26,38), + std::make_tuple(128,128,26,38), // std::make_tuple(1280,640,52,76), // std::make_tuple(1920,1280,26,38), // std::make_tuple(2560,1280,26,38), From 396f55831c97b180329152a7a6c662b45248f24e Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sat, 25 Oct 2025 18:14:12 -0400 Subject: [PATCH 032/122] WIP: bug fix --- ggml/src/ggml-cuda/conv2d-implicit.cu | 23 ++++++++++++--------- tests/test-conv2d-implicit.cpp | 29 ++++++++++++++++++--------- 2 files changed, 33 insertions(+), 19 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index fa7a905d39..6601d160d7 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -1076,7 +1076,8 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, // prefetch the first block tile of A,B into shared memory // half* A_block_gmem = input + (block_m * BM * A_stride); const half* A_block_gmem = input; - const half* B_block_gmem = kernel + (block_n * weightKOffset); +// const half* B_block_gmem = kernel + (block_n * weightKOffset); + const half* B_block_gmem = kernel + block_n * BN * weightKOffset; tileMemcpySwizzleA(A_block_gmem, A_block_smem, inChannelOffset, param); tileMemcpySwizzleB(B_block_gmem, B_block_smem, weightKOffset, param); @@ -1097,7 +1098,8 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, { // half* A_block_gmem = A + (block_m * BM * A_stride) + (block_k * BK); const half* A_block_gmem = input; - const half* B_block_gmem = kernel + (block_n * weightKOffset); + // const half* B_block_gmem = kernel + (block_n * weightKOffset); + const half* B_block_gmem = kernel + (block_n * BN * weightKOffset); tileMemcpyLoadA(A_block_gmem, A_gmem_cache_reg, block_k * BK, inChannelOffset, param); tileMemcpyLoadB(B_block_gmem, B_gmem_cache_reg, block_k * BK, weightKOffset, param); 
} @@ -1119,6 +1121,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, { asm volatile ( "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " + // "mma.sync.aligned.m16n8k8.row.row.f16.f16.f16.f16 " "{%0, %1}, " "{%2, %3}, " "{%4}, " @@ -1130,14 +1133,14 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, ); } } - // if(threadIdx.x == 0 && threadIdx.y ==0 && blockIdx.x ==0 && blockIdx.y ==0){ - // printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(acc_register_[3][0][0]), __half2float(acc_register_[3][0][1]), - // __half2float(acc_register_[3][0][2]), __half2float(acc_register_[3][0][3])); - // printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(A_register_[3][mma_k][0]), __half2float(A_register_[3][mma_k][1]), - // __half2float(A_register_[3][mma_k][2]), __half2float(A_register_[3][mma_k][3])); - // printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(B_register_[mma_k][0][0]), __half2float(B_register_[mma_k][0][1]), - // __half2float(B_register_[mma_k][0][2]), __half2float(B_register_[mma_k][0][3])); - // } + if(threadIdx.x == 28 && threadIdx.y ==0 && blockIdx.x ==0 && blockIdx.y ==0){ + printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(acc_register_[0][0][0]), __half2float(acc_register_[0][0][1]), + __half2float(acc_register_[0][0][2]), __half2float(acc_register_[0][0][3])); + printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(A_register_[0][mma_k][0]), __half2float(A_register_[0][mma_k][1]), + __half2float(A_register_[0][mma_k][2]), __half2float(A_register_[0][mma_k][3])); + printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(B_register_[mma_k][0][0]), __half2float(B_register_[mma_k][0][1]), + __half2float(B_register_[mma_k][0][2]), __half2float(B_register_[mma_k][0][3])); + } // if(threadIdx.x < 4 && threadIdx.y ==0 && blockIdx.x ==0 && blockIdx.y ==0){ // printf("A %d, %d, %d: %f, %f \n", block_k, mma_k, threadIdx.x, __half2float(A_register_[3][mma_k][0]), __half2float(A_register_[3][mma_k][1])); // printf("B %d, %d, %d: %f, %f \n", block_k, mma_k, threadIdx.x, __half2float(B_register_[mma_k][0][0]), __half2float(B_register_[mma_k][0][1])); diff --git a/tests/test-conv2d-implicit.cpp b/tests/test-conv2d-implicit.cpp index 4b9222a19e..3a5f928ee6 100644 --- a/tests/test-conv2d-implicit.cpp +++ b/tests/test-conv2d-implicit.cpp @@ -42,13 +42,18 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, bool use_gpu // create data int KW = 3, KH = 3, IC = ic, OC = oc; int IW = iw, IH = ih, N = 1; + srand(time(NULL)); // printf(" input: IC = %d, OC = %d, IW = %d, IH = %d \n ", IC, OC, IW, IH); // Initialize adata std::vector adata(KW * KH * IC * OC); for (int i = 0; i < KW * KH * IC * OC; i++) { - adata[i] = 0.2f; + // adata[i] = 2.f; + adata[i] = (float)(i%KW)-1.f; + // adata[i] = (rand() % 255) / 255.0; + // float r = -1.f + static_cast (rand()) /( static_cast (RAND_MAX/(1.f-(-1.f)))); + // adata[i] = r; } // Convert adata to fp16 format @@ -58,7 +63,11 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, bool use_gpu // Initialize bdata std::vector bdata(IW * IH * IC * N); for (int i = 0; i < IW * IH * IC * N; i++) { - bdata[i] = 1.5f; + bdata[i] = (float)(i%IW)/10.f; + // bdata[i] = 1.5f; + // bdata[i] = (rand() % 255) / 255.0; + // float r = -1.f + static_cast (rand()) /( static_cast (RAND_MAX/(1.f-(-1.f)))); + // bdata[i] = r; } size_t buffer_size = 0; @@ -344,7 +353,7 @@ int main(void) // 
std::make_tuple(640,640,52,76), // std::make_tuple(640,640,104,152), // std::make_tuple(960,320,104,152), - std::make_tuple(128,128,26,38), + std::make_tuple(640,128,26,38), // std::make_tuple(1280,640,52,76), // std::make_tuple(1920,1280,26,38), // std::make_tuple(2560,1280,26,38), @@ -378,7 +387,7 @@ int main(void) int iterations = 0; double run_time0; - std::vector conv2d_data = compute_graph(model, allocr, build_graph_0, iterations, &run_time0); + std::vector im2col_data = compute_graph(model, allocr, build_graph_0, iterations, &run_time0); ggml_gallocr_free(allocr); @@ -399,7 +408,7 @@ int main(void) double run_time1; // std::vector wino_data = compute_graph(model, allocr, build_graph_1, iterations, &run_time1); - conv2d_data = compute_graph(model, allocr, build_graph_1, iterations, &run_time1); + std::vector conv2d_data = compute_graph(model, allocr, build_graph_1, iterations, &run_time1); ggml_gallocr_free(allocr); @@ -439,11 +448,13 @@ int main(void) // for(int i = 0; i < ggml_nelements(wino_res); i++) { for(int i = 0; i < 26*38; i++) { // for(int i = 0; i < conv2d_data.size(); i++) { - float diff = fabs(conv2d_data[i] - wino_data[i]); + // float diff = fabs(conv2d_data[i] - wino_data[i]); + float diff = fabs(im2col_data[i] - wino_data[i]); + float diff1 = fabs(im2col_data[i] - conv2d_data[i]); // if(diff > 1.e-4) { - printf("(%f, %f, %f, %d) \n", - conv2d_data[i], - wino_data[i], diff, i); + printf("(%f, %f, %f, %f, %f, %d) \n", + im2col_data[i], conv2d_data[i], + wino_data[i], diff, diff1, i); // break; // } } From 475f9879c5a8c96ef154672004ac727291682a62 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sat, 25 Oct 2025 20:24:14 -0400 Subject: [PATCH 033/122] WIP: fixed another bug --- ggml/src/ggml-cuda/conv2d-implicit.cu | 62 ++++++++++++++++++++------- tests/test-conv2d-implicit.cpp | 14 +++--- 2 files changed, 54 insertions(+), 22 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 6601d160d7..d9686ae344 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -931,7 +931,8 @@ __device__ __forceinline__ void ldmatrix_b( asm volatile ( - "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " + // "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" : "=r"(reg_[0][4]), "=r"(reg_[0][5]), "=r"(reg_[0][6]), "=r"(reg_[0][7]) // : "r"(src_addr ^ 0b1000000) @@ -941,14 +942,16 @@ __device__ __forceinline__ void ldmatrix_b( src_addr ^= 0b10000; asm volatile ( - "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " + // "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" : "=r"(reg_[1][0]), "=r"(reg_[1][1]), "=r"(reg_[1][2]), "=r"(reg_[1][3]) : "r"(src_addr) ); asm volatile ( - "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " + // "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" : "=r"(reg_[1][4]), "=r"(reg_[1][5]), "=r"(reg_[1][6]), "=r"(reg_[1][7]) // : "r"(src_addr ^ 0b1000000) @@ -959,14 +962,16 @@ __device__ __forceinline__ void ldmatrix_b( src_addr ^= 0b110000; asm volatile ( - "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " + // "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" : "=r"(reg_[2][0]), "=r"(reg_[2][1]), "=r"(reg_[2][2]), "=r"(reg_[2][3]) : "r"(src_addr) ); asm volatile ( - 
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " + // "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" : "=r"(reg_[2][4]), "=r"(reg_[2][5]), "=r"(reg_[2][6]), "=r"(reg_[2][7]) // : "r"(src_addr ^ 0b1000000) @@ -976,14 +981,16 @@ __device__ __forceinline__ void ldmatrix_b( src_addr ^= 0b10000; asm volatile ( - "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " + // "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" : "=r"(reg_[3][0]), "=r"(reg_[3][1]), "=r"(reg_[3][2]), "=r"(reg_[3][3]) : "r"(src_addr) ); asm volatile ( - "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " + // "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" : "=r"(reg_[3][4]), "=r"(reg_[3][5]), "=r"(reg_[3][6]), "=r"(reg_[3][7]) // : "r"(src_addr ^ 0b1000000) @@ -1043,6 +1050,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, // declare register storage // ptx instructions expect uint32_t registers, where each uint32_t is 2 halfs packed together uint32_t acc_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][2]; +// float acc_register_[mma_tiles_per_warp_m][mma_tiles_per_warp_n][4]; uint32_t A_register[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2]; uint32_t B_register[mma_tiles_per_warp_k][mma_tiles_per_warp_n]; @@ -1131,16 +1139,40 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, "r"(B_register[mma_k][mma_n]) "r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1]) ); + // asm volatile ( + // "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + // "{%0, %1, %2, %3}," + // "{%4, %5}," + // "{%6}," + // "{%7, %8, %9, %10};\n" + // : "=f"(acc_register_[mma_m][mma_n][0]), "=f"(acc_register_[mma_m][mma_n][1]), + // "=f"(acc_register_[mma_m][mma_n][2]), "=f"(acc_register_[mma_m][mma_n][3]) + // : "r"(A_register[mma_m][mma_k][0]), "r"(A_register[mma_m][mma_k][1]), + // "r"(B_register[mma_k][mma_n]), + // "f"(acc_register_[mma_m][mma_n][0]), "f"(acc_register_[mma_m][mma_n][1]), + // "f"(acc_register_[mma_m][mma_n][2]), "f"(acc_register_[mma_m][mma_n][3]) + // ); } } - if(threadIdx.x == 28 && threadIdx.y ==0 && blockIdx.x ==0 && blockIdx.y ==0){ - printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(acc_register_[0][0][0]), __half2float(acc_register_[0][0][1]), - __half2float(acc_register_[0][0][2]), __half2float(acc_register_[0][0][3])); - printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(A_register_[0][mma_k][0]), __half2float(A_register_[0][mma_k][1]), - __half2float(A_register_[0][mma_k][2]), __half2float(A_register_[0][mma_k][3])); - printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(B_register_[mma_k][0][0]), __half2float(B_register_[mma_k][0][1]), - __half2float(B_register_[mma_k][0][2]), __half2float(B_register_[mma_k][0][3])); - } + // if(threadIdx.x == 12 && threadIdx.y ==0 && blockIdx.x ==0 && blockIdx.y ==0){ + // printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(acc_register_[0][0][0]), __half2float(acc_register_[0][0][1]), + // __half2float(acc_register_[0][0][2]), __half2float(acc_register_[0][0][3])); + // printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, acc_register_[0][0][0], acc_register_[0][0][1], + // acc_register_[0][0][2], acc_register_[0][0][3]); + // printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(A_register_[0][mma_k][0]), 
__half2float(A_register_[0][mma_k][1]), + // __half2float(A_register_[0][mma_k][2]), __half2float(A_register_[0][mma_k][3])); + // printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(B_register_[mma_k][0][0]), __half2float(B_register_[mma_k][0][1]), + // __half2float(B_register_[mma_k][0][2]), __half2float(B_register_[mma_k][0][3])); + // printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, acc_register_[1][0][0], acc_register_[1][0][1], + // acc_register_[1][0][2], acc_register_[1][0][3]); + // printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(A_register_[1][mma_k][0]), __half2float(A_register_[1][mma_k][1]), + // __half2float(A_register_[1][mma_k][2]), __half2float(A_register_[1][mma_k][3])); + // printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, acc_register_[3][0][0], acc_register_[3][0][1], + // acc_register_[3][0][2], acc_register_[3][0][3]); + // printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(A_register_[3][mma_k][0]), __half2float(A_register_[3][mma_k][1]), + // __half2float(A_register_[3][mma_k][2]), __half2float(A_register_[3][mma_k][3])); + // printf(" %d, %d: %f, %f, \n", block_k, mma_k, __half2float(B_register_[mma_k][0][0]), __half2float(B_register_[mma_k][0][1])); + // } // if(threadIdx.x < 4 && threadIdx.y ==0 && blockIdx.x ==0 && blockIdx.y ==0){ // printf("A %d, %d, %d: %f, %f \n", block_k, mma_k, threadIdx.x, __half2float(A_register_[3][mma_k][0]), __half2float(A_register_[3][mma_k][1])); // printf("B %d, %d, %d: %f, %f \n", block_k, mma_k, threadIdx.x, __half2float(B_register_[mma_k][0][0]), __half2float(B_register_[mma_k][0][1])); diff --git a/tests/test-conv2d-implicit.cpp b/tests/test-conv2d-implicit.cpp index 3a5f928ee6..bf18d4ed80 100644 --- a/tests/test-conv2d-implicit.cpp +++ b/tests/test-conv2d-implicit.cpp @@ -50,10 +50,10 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, bool use_gpu std::vector adata(KW * KH * IC * OC); for (int i = 0; i < KW * KH * IC * OC; i++) { // adata[i] = 2.f; - adata[i] = (float)(i%KW)-1.f; + // adata[i] = (float)(i%KW)-1.f; // adata[i] = (rand() % 255) / 255.0; - // float r = -1.f + static_cast (rand()) /( static_cast (RAND_MAX/(1.f-(-1.f)))); - // adata[i] = r; + float r = -1.f + static_cast (rand()) /( static_cast (RAND_MAX/(1.f-(-1.f)))); + adata[i] = r; } // Convert adata to fp16 format @@ -63,11 +63,11 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, bool use_gpu // Initialize bdata std::vector bdata(IW * IH * IC * N); for (int i = 0; i < IW * IH * IC * N; i++) { - bdata[i] = (float)(i%IW)/10.f; + // bdata[i] = (float)(i%IW)/10.f; // bdata[i] = 1.5f; // bdata[i] = (rand() % 255) / 255.0; - // float r = -1.f + static_cast (rand()) /( static_cast (RAND_MAX/(1.f-(-1.f)))); - // bdata[i] = r; + float r = -1.f + static_cast (rand()) /( static_cast (RAND_MAX/(1.f-(-1.f)))); + bdata[i] = r; } size_t buffer_size = 0; @@ -452,7 +452,7 @@ int main(void) float diff = fabs(im2col_data[i] - wino_data[i]); float diff1 = fabs(im2col_data[i] - conv2d_data[i]); // if(diff > 1.e-4) { - printf("(%f, %f, %f, %f, %f, %d) \n", + printf("(%7.3f, %7.3f, %7.3f, %.2f, %.2f, %d) \n", im2col_data[i], conv2d_data[i], wino_data[i], diff, diff1, i); // break; From c68fe36ae28716714668560937b82928acb97502 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sat, 25 Oct 2025 21:57:39 -0400 Subject: [PATCH 034/122] WIP: cleanup; enhanced test case --- ggml/src/ggml-cuda/conv2d-implicit.cu | 4 -- tests/test-conv2d-implicit.cpp | 70 ++++++++++++++------------- 2 files changed, 37 
insertions(+), 37 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index d9686ae344..cb5d4359dd 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -259,10 +259,6 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, __syncthreads(); - if(tx == 0 && bx == 0 && by == 0 && z == 0){ - printf("non tensor \n"); - } - // if(tx == 0 && bx == 0 && by == 0 && z == 0){ // for(int i=0; i < 128; ++i) // printf("%.2f,", smeminput[i]); diff --git a/tests/test-conv2d-implicit.cpp b/tests/test-conv2d-implicit.cpp index bf18d4ed80..f1f7b81686 100644 --- a/tests/test-conv2d-implicit.cpp +++ b/tests/test-conv2d-implicit.cpp @@ -38,9 +38,9 @@ struct test_model { -void load_model(test_model & model, int ic, int oc, int iw, int ih, bool use_gpu = false ) { +void load_model(test_model & model, int ic, int oc, int iw, int ih, int kw = 3, int kh = 3, bool use_gpu = false ) { // create data - int KW = 3, KH = 3, IC = ic, OC = oc; + int KW = kw, KH = kh, IC = ic, OC = oc; int IW = iw, IH = ih, N = 1; srand(time(NULL)); @@ -347,21 +347,24 @@ std::vector compute_graph(const test_model & model, ggml_gallocr_t allocr int main(void) { ggml_time_init(); - std::vector> configs = { - // std::make_tuple(64,64,48,64), - // std::make_tuple(320,320,104,152), - // std::make_tuple(640,640,52,76), - // std::make_tuple(640,640,104,152), - // std::make_tuple(960,320,104,152), - std::make_tuple(640,128,26,38), - // std::make_tuple(1280,640,52,76), - // std::make_tuple(1920,1280,26,38), - // std::make_tuple(2560,1280,26,38), - // std::make_tuple(512,512,104,152), - // std::make_tuple(512,512,208,304), - // std::make_tuple(512,256,416,608), - // std::make_tuple(256,128,832,1216), - // std::make_tuple(256,256,832,1216), + std::vector> configs = { + std::make_tuple(64,64,48,64,3,3), + std::make_tuple(320,320,104,152,3,3), + std::make_tuple(640,640,52,76,3,3), + std::make_tuple(640,640,104,152,3,3), + std::make_tuple(960,320,104,152,3,3), + std::make_tuple(1280,1280,26,38,3,3), + std::make_tuple(1280,1280,26,38,1,1), + std::make_tuple(256,128,768,1024,3,3), + std::make_tuple(256,128,768,1024,1,1), + std::make_tuple(1280,640,52,76,3,3), + std::make_tuple(1920,1280,26,38,3,3), + std::make_tuple(2560,1280,26,38,3,3), + std::make_tuple(512,512,104,152,3,3), + std::make_tuple(512,512,208,304,3,3), + std::make_tuple(512,256,416,608,3,3), + std::make_tuple(256,128,832,1216,3,3), + std::make_tuple(256,256,832,1216,3,3), // std::make_tuple(320,256,1024,1920) }; @@ -369,7 +372,8 @@ int main(void) for (auto c : configs){ test_model model; - load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c), std::get<3>(c), true); + load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c), + std::get<3>(c), std::get<4>(c), std::get<5>(c), true); ggml_gallocr_t allocr = NULL; allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend)); @@ -384,7 +388,7 @@ int main(void) struct ggml_cgraph * gf_res_0 = NULL; - int iterations = 0; + int iterations = 20; double run_time0; std::vector im2col_data = compute_graph(model, allocr, build_graph_0, iterations, &run_time0); @@ -438,26 +442,26 @@ int main(void) fprintf(stderr, "| --- | --- | --- | --- | --- | --- | --- \n"); } - fprintf(stderr, " | (%d, %d, %d, %d) | %.2f ms | %.2f MB | %.2f ms | %.2f MB | %.2f ms | %.2f MB\n", - std::get<0>(c), std::get<1>(c), std::get<2>(c), std::get<3>(c), + fprintf(stderr, " | (%d, %d, %d, %d, %d, %d) | %.2f ms | %.2f MB | 
%.2f ms | %.2f MB | %.2f ms | %.2f MB\n", + std::get<0>(c), std::get<1>(c), std::get<2>(c), std::get<3>(c), std::get<4>(c), std::get<5>(c), run_time0, mem_size0/1024.0f/1024.0f, run_time1, mem_size1/1024.0f/1024.0f, run_time2, mem_size2/1024.0f/1024.0f); // for(int i = 0; i < ggml_nelements(wino_res); i++) { - for(int i = 0; i < 26*38; i++) { - // for(int i = 0; i < conv2d_data.size(); i++) { - // float diff = fabs(conv2d_data[i] - wino_data[i]); - float diff = fabs(im2col_data[i] - wino_data[i]); - float diff1 = fabs(im2col_data[i] - conv2d_data[i]); - // if(diff > 1.e-4) { - printf("(%7.3f, %7.3f, %7.3f, %.2f, %.2f, %d) \n", - im2col_data[i], conv2d_data[i], - wino_data[i], diff, diff1, i); - // break; - // } - } + // for(int i = 0; i < 26*38; i++) { + // // for(int i = 0; i < conv2d_data.size(); i++) { + // // float diff = fabs(conv2d_data[i] - wino_data[i]); + // float diff = fabs(im2col_data[i] - wino_data[i]); + // float diff1 = fabs(im2col_data[i] - conv2d_data[i]); + // // if(diff > 1.e-4) { + // printf("(%7.3f, %7.3f, %7.3f, %.2f, %.2f, %d) \n", + // im2col_data[i], conv2d_data[i], + // wino_data[i], diff, diff1, i); + // // break; + // // } + // } ggml_free(model.ctx); ggml_backend_buffer_free(model.buffer); From 30990788e8f4c24efa3c4eba0f786098951c76e5 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Mon, 27 Oct 2025 08:29:20 -0400 Subject: [PATCH 035/122] WIP --- tests/test-conv2d-implicit.cpp | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/tests/test-conv2d-implicit.cpp b/tests/test-conv2d-implicit.cpp index f1f7b81686..a0e078eaef 100644 --- a/tests/test-conv2d-implicit.cpp +++ b/tests/test-conv2d-implicit.cpp @@ -348,23 +348,23 @@ int main(void) { ggml_time_init(); std::vector> configs = { - std::make_tuple(64,64,48,64,3,3), - std::make_tuple(320,320,104,152,3,3), - std::make_tuple(640,640,52,76,3,3), - std::make_tuple(640,640,104,152,3,3), - std::make_tuple(960,320,104,152,3,3), - std::make_tuple(1280,1280,26,38,3,3), - std::make_tuple(1280,1280,26,38,1,1), - std::make_tuple(256,128,768,1024,3,3), + // std::make_tuple(64,64,48,64,3,3), + // std::make_tuple(320,320,104,152,3,3), + // std::make_tuple(640,640,52,76,3,3), + // std::make_tuple(640,640,104,152,3,3), + // std::make_tuple(960,320,104,152,3,3), + // std::make_tuple(1280,1280,26,38,3,3), + // std::make_tuple(1280,1280,26,38,1,1), + // std::make_tuple(256,128,768,1024,3,3), std::make_tuple(256,128,768,1024,1,1), - std::make_tuple(1280,640,52,76,3,3), - std::make_tuple(1920,1280,26,38,3,3), - std::make_tuple(2560,1280,26,38,3,3), - std::make_tuple(512,512,104,152,3,3), - std::make_tuple(512,512,208,304,3,3), - std::make_tuple(512,256,416,608,3,3), - std::make_tuple(256,128,832,1216,3,3), - std::make_tuple(256,256,832,1216,3,3), + // std::make_tuple(1280,640,52,76,3,3), + // std::make_tuple(1920,1280,26,38,3,3), + // std::make_tuple(2560,1280,26,38,3,3), + // std::make_tuple(512,512,104,152,3,3), + // std::make_tuple(512,512,208,304,3,3), + // std::make_tuple(512,256,416,608,3,3), + // std::make_tuple(256,128,832,1216,3,3), + // std::make_tuple(256,256,832,1216,3,3), // std::make_tuple(320,256,1024,1920) }; From cc327f5224d736d17a8e6d610159c34bf8d8c3de Mon Sep 17 00:00:00 2001 From: bssrdf Date: Mon, 27 Oct 2025 11:23:27 -0400 Subject: [PATCH 036/122] added a specialization for cuda copy op when tensor is transposed --- ggml/src/ggml-cuda/cpy.cu | 63 ++++++++++++++++++++++++++++++++++++-- ggml/src/ggml-cuda/cpy.cuh | 5 +++ 2 files changed, 65 insertions(+), 3 deletions(-) 
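Note on the kernel added below: cpy_flt_transpose follows the standard tiled shared-memory transpose pattern, staging a TILE_DIM x TILE_DIM tile in shared memory so that both the global-memory reads and the writes of the transposed result stay coalesced. A minimal standalone sketch of that pattern, assuming the TILE_DIM = 32 / BLOCK_ROWS = 8 constants this patch adds to cpy.cuh (the names and launch setup below are illustrative, not part of the patch):

    // Illustrative sketch: transpose a height x width row-major matrix into a
    // width x height one. The +1 padding on the tile is the textbook way to
    // avoid shared-memory bank conflicts; later patches in this series use an
    // XOR swizzle of the column index to the same end.
    constexpr int TILE_DIM_SK   = 32;
    constexpr int BLOCK_ROWS_SK = 8;

    template <typename T>
    __global__ void transpose_tile_sketch(const T * src, T * dst, int width, int height) {
        __shared__ T tile[TILE_DIM_SK][TILE_DIM_SK + 1];

        int x = blockIdx.x * TILE_DIM_SK + threadIdx.x;   // column in src
        int y = blockIdx.y * TILE_DIM_SK + threadIdx.y;   // row in src
        for (int j = 0; j < TILE_DIM_SK; j += BLOCK_ROWS_SK) {
            if (x < width && y + j < height) {
                tile[threadIdx.y + j][threadIdx.x] = src[(size_t)(y + j)*width + x];
            }
        }
        __syncthreads();

        x = blockIdx.y * TILE_DIM_SK + threadIdx.x;       // column in dst (= row of src)
        y = blockIdx.x * TILE_DIM_SK + threadIdx.y;       // row in dst (= column of src)
        for (int j = 0; j < TILE_DIM_SK; j += BLOCK_ROWS_SK) {
            if (x < height && y + j < width) {
                dst[(size_t)(y + j)*height + x] = tile[threadIdx.x][threadIdx.y + j];
            }
        }
    }

Launched with dim3 grid((width + 31)/32, (height + 31)/32) and dim3 block(32, 8), each block moves one 32x32 tile; the batched kernel in the diff below additionally loops over the ne/(ne00*ne01) matrices of the tensor via the z grid dimension.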
diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 8567c3d5a1..08edd6cc3b 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -37,6 +37,48 @@ static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne cpy_1(cx + x_offset, cdst + dst_offset); } + +template +static __global__ void cpy_flt_transpose(char * cx, char * cdst_direct,, const int ne, + const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) { + + char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct; + + const T* src = reinterpret_cast(cx); + T* dst = reinterpret_cast(cdst); + + const int64_t nmat = ne /(ne00 * ne01); + const int64_t n = ne00 * ne01; + // const int64_t n = ne01 * ne02; + int width = gridDim.x * TILE_DIM; + int x = blockIdx.x * TILE_DIM + threadIdx.x; + int y = blockIdx.y * TILE_DIM + threadIdx.y; + int tx = blockIdx.y * TILE_DIM + threadIdx.x; // transpose block offset + int ty = blockIdx.x * TILE_DIM + threadIdx.y; + + __shared__ T tile[TILE_DIM * TILE_DIM]; + + for(int i = 0; i < BLOCK_NM; ++i){ + const unsigned int imat = blockIdx.z * BLOCK_NM + i; + if(imat < nmat){ + for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS){ + const unsigned int idx = (y+j)*width + x; + if(idx < n) + tile[threadIdx.y+j][threadIdx.x] = src[imat*n + idx]; + } + __syncthreads(); + + for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS){ + const unsigned int idx = (ty+j)*width + tx; + if(idx < n) + dst[imat*n + idx] = tile[threadIdx.x][threadIdx.y + j]; + } + } + } +} + static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) { float * cdstf = (float *)(cdsti); @@ -143,10 +185,25 @@ static void ggml_cpy_flt_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { - const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; - cpy_flt><<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + if constexpr (std::is_same_v && std::is_same_v || + std::is_same_v && std::is_same_v + ){ + if (ne00 == ne11 && ne01 = ne10 && nb00 == nb11 && nb10 == nb01){ //transpose + dim3 dimGrid( (ne00 + TILE_DIM - 1) / TILE_DIM, + (ne01 + TILE_DIM - 1) / TILE_DIM, + (ne/(ne00*ne01) + BLOCK_NM - 1) / BLOCK_NM ); + dim3 dimBlock(TILE_DIM, BLOCK_ROWS, 1); + cpy_flt_transpose<<>> + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + } else{ // other + cpy_flt><<>> + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + } + } else{ + cpy_flt><<>> + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + } } static void ggml_cpy_f32_q8_0_cuda( diff --git a/ggml/src/ggml-cuda/cpy.cuh b/ggml/src/ggml-cuda/cpy.cuh index 0bd3c0c6f8..211348b66a 100644 --- a/ggml/src/ggml-cuda/cpy.cuh +++ 
b/ggml/src/ggml-cuda/cpy.cuh @@ -2,6 +2,11 @@ #define CUDA_CPY_BLOCK_SIZE 64 +const int TILE_DIM = 32; +const int BLOCK_ROWS = 8; +const int BLOCK_NM = 8; + + void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection = false); void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst); From a3784e17adbd873eb39723a4090761407a53dfa3 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Mon, 27 Oct 2025 15:09:03 -0400 Subject: [PATCH 037/122] WIP: debugging cpy transpose --- ggml/src/ggml-cuda/cpy.cu | 60 ++++++++++++++++++++++------------ ggml/src/ggml.c | 18 ++++++++-- tests/test-backend-ops.cpp | 33 +++++++++++-------- tests/test-conv2d-implicit.cpp | 10 +++--- 4 files changed, 81 insertions(+), 40 deletions(-) diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 08edd6cc3b..67901d915a 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -39,7 +39,7 @@ static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne template -static __global__ void cpy_flt_transpose(char * cx, char * cdst_direct,, const int ne, +static __global__ void cpy_flt_transpose(const char * cx, char * cdst_direct, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) { @@ -58,22 +58,31 @@ static __global__ void cpy_flt_transpose(char * cx, char * cdst_direct,, const i int tx = blockIdx.y * TILE_DIM + threadIdx.x; // transpose block offset int ty = blockIdx.x * TILE_DIM + threadIdx.y; - __shared__ T tile[TILE_DIM * TILE_DIM]; + // __shared__ T tile[TILE_DIM * TILE_DIM]; + __shared__ T tile[TILE_DIM][TILE_DIM]; for(int i = 0; i < BLOCK_NM; ++i){ const unsigned int imat = blockIdx.z * BLOCK_NM + i; if(imat < nmat){ for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS){ const unsigned int idx = (y+j)*width + x; - if(idx < n) - tile[threadIdx.y+j][threadIdx.x] = src[imat*n + idx]; + if(idx < n){ + const int row = threadIdx.y+j; + const int col = threadIdx.x ^ row; + // tile[threadIdx.y+j][threadIdx.x] = src[imat*n + idx]; + tile[row][col] = src[imat*n + idx]; + } } __syncthreads(); for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS){ const unsigned int idx = (ty+j)*width + tx; - if(idx < n) - dst[imat*n + idx] = tile[threadIdx.x][threadIdx.y + j]; + if(idx < n){ + // const int row = threadIdx.x; + const int col = (threadIdx.y+j) ^ threadIdx.x; + // dst[imat*n + idx] = tile[threadIdx.x][threadIdx.y + j]; + dst[imat*n + idx] = tile[threadIdx.x][col]; + } } } } @@ -180,30 +189,33 @@ void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_des #endif } -template +template static void ggml_cpy_flt_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; - if constexpr (std::is_same_v && std::is_same_v || - std::is_same_v && std::is_same_v - ){ - if (ne00 == ne11 && ne01 = ne10 && nb00 == nb11 && nb10 == nb01){ //transpose + if constexpr ((std::is_same_v && std::is_same_v || + std::is_same_v && 
std::is_same_v) + && transpose){ + // printf("cuda cpy transpose ne=%d ne00=%d ne01=%d ne10=%d ne11=%d\n", ne, ne00, ne01, ne10, ne11); + // printf("cuda cpy transpose nb00=%d nb01=%d nb10=%d nb11=%d\n", nb00, nb01, nb10, nb11); + // if (ne00 == ne11 && ne01 == ne10 && nb00 == nb11 && nb10 == nb01){ //transpose + // if (transpose) { //transpose + // printf("cuda cpy transpose ne=%d ne00=%d ne01=%d ne10=%d ne11=%d\n", ne, ne00, ne01, ne10, ne11); dim3 dimGrid( (ne00 + TILE_DIM - 1) / TILE_DIM, (ne01 + TILE_DIM - 1) / TILE_DIM, (ne/(ne00*ne01) + BLOCK_NM - 1) / BLOCK_NM ); dim3 dimBlock(TILE_DIM, BLOCK_ROWS, 1); - cpy_flt_transpose<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); - } else{ // other - cpy_flt><<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); - } - } else{ + cpy_flt_transpose<<>>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + } else{ // other cpy_flt><<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } + // } else{ + // cpy_flt><<>> + // (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); + // } } static void ggml_cpy_f32_q8_0_cuda( @@ -389,7 +401,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + if(src1->op_params[10] == 999){ + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) { ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { @@ -420,7 +436,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + if(src1->op_params[10] == 999){ + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else { + ggml_cpy_flt_cuda 
(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) { ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 7fa97e84de..2257847f0f 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3301,6 +3301,9 @@ static struct ggml_tensor * ggml_cont_impl( result->op = GGML_OP_CONT; result->src[0] = a; + if (a->op == GGML_OP_TRANSPOSE) { + result->op_params[10] = a->op_params[10]; // preserve the original order + } return result; } @@ -3614,6 +3617,7 @@ struct ggml_tensor * ggml_transpose( result->op = GGML_OP_TRANSPOSE; result->src[0] = a; + result->op_params[10] = 999; // the transpose flag return result; } @@ -4609,8 +4613,18 @@ struct ggml_tensor * ggml_conv_2d_implicitgemm( struct ggml_tensor *ap, *bp; if(layout == 0){ - ap = ggml_cont(ctx, ggml_permute(ctx, a, 1, 2, 0, 3)); - bp = ggml_cont(ctx, ggml_permute(ctx, b, 1, 2, 0, 3)); + // ap = ggml_cont(ctx, ggml_permute(ctx, a, 1, 2, 0, 3)); + // bp = ggml_cont(ctx, ggml_permute(ctx, b, 1, 2, 0, 3)); + ap = ggml_reshape_4d(ctx, + ggml_cont(ctx, + ggml_transpose(ctx, + ggml_reshape_3d(ctx, a, a->ne[0]*a->ne[1], a->ne[2], a->ne[3]))), + a->ne[2], a->ne[0], a->ne[1], a->ne[3]); + bp = ggml_reshape_4d(ctx, + ggml_cont(ctx, + ggml_transpose(ctx, + ggml_reshape_3d(ctx, b, b->ne[0]*b->ne[1], b->ne[2], b->ne[3]))), + b->ne[2], b->ne[0], b->ne[1], b->ne[3]); } else{ ap = a; bp = b; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 1ffa3cf6e4..6c0b5d17a6 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2414,6 +2414,7 @@ struct test_cpy : public test_case { const std::array permute_dst; bool _src_use_permute; bool _dst_use_permute; + bool is_transpose; std::string vars() override { return VARS_TO_STR5(type_src, type_dst, ne, permute_src, permute_dst); @@ -2430,10 +2431,12 @@ struct test_cpy : public test_case { test_cpy(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32, std::array ne = {10, 10, 10, 1}, std::array permute_src = {0, 0, 0, 0}, - std::array permute_dst = {0, 0, 0, 0}) + std::array permute_dst = {0, 0, 0, 0}, + bool transpose = false) : type_src(type_src), type_dst(type_dst), ne(ne), permute_src(permute_src), permute_dst(permute_dst), _src_use_permute(permute_src[0] + permute_src[1] + permute_src[2] + permute_src[3] > 0), - _dst_use_permute(permute_dst[0] + permute_dst[1] + permute_dst[2] + permute_dst[3] > 0) {} + _dst_use_permute(permute_dst[0] + permute_dst[1] + permute_dst[2] + permute_dst[3] > 0), + is_transpose(transpose) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data()); @@ -2454,6 +2457,8 @@ struct test_cpy : public test_case { } ggml_tensor * out = ggml_cpy(ctx, src, dst); + if(is_transpose) + dst->op_params[10] = 999; ggml_set_name(out, "out"); return out; @@ -4258,14 +4263,14 @@ struct test_conv_2d_implicit : public test_case { ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data()); ggml_set_name(kernel, "kernel"); - if (cwhn) { - // change memory layout to channel-most-contiguous (CWHN), - // then permute it back so NE matches the original 
input - input = ggml_cont(ctx, ggml_permute(ctx, input, 1, 2, 0, 3)); - input = ggml_permute(ctx, input, 2, 0, 1, 3); - kernel = ggml_cont(ctx, ggml_permute(ctx, kernel, 2, 3, 1, 0)); - kernel = ggml_permute(ctx, kernel, 3, 2, 0, 1); - } + // if (cwhn) { + // // change memory layout to channel-most-contiguous (CWHN), + // // then permute it back so NE matches the original input + // input = ggml_cont(ctx, ggml_permute(ctx, input, 1, 2, 0, 3)); + // input = ggml_permute(ctx, input, 2, 0, 1, 3); + // kernel = ggml_cont(ctx, ggml_permute(ctx, kernel, 2, 3, 1, 0)); + // kernel = ggml_permute(ctx, kernel, 3, 2, 0, 1); + // } ggml_tensor * out = ggml_conv_2d_implicitgemm(ctx, kernel, input, stride0, stride1, padding0, padding1, dilation0, dilation1, cwhn?0:1); @@ -6831,9 +6836,11 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1})); test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1})); - test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F16, {512, 3072, 1, 1})); - test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3})); - test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3})); + // test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F16, {512, 3072, 1, 1})); + // test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3})); + // test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3})); + test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {1, 0, 2, 3}, true)); + test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {1, 0, 2, 3}, false)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 256, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); diff --git a/tests/test-conv2d-implicit.cpp b/tests/test-conv2d-implicit.cpp index a0e078eaef..f3f4f91700 100644 --- a/tests/test-conv2d-implicit.cpp +++ b/tests/test-conv2d-implicit.cpp @@ -353,10 +353,10 @@ int main(void) // std::make_tuple(640,640,52,76,3,3), // std::make_tuple(640,640,104,152,3,3), // std::make_tuple(960,320,104,152,3,3), - // std::make_tuple(1280,1280,26,38,3,3), + std::make_tuple(1280,1280,26,38,3,3), // std::make_tuple(1280,1280,26,38,1,1), // std::make_tuple(256,128,768,1024,3,3), - std::make_tuple(256,128,768,1024,1,1), + // std::make_tuple(256,128,768,1024,1,1), // std::make_tuple(1280,640,52,76,3,3), // std::make_tuple(1920,1280,26,38,3,3), // std::make_tuple(2560,1280,26,38,3,3), @@ -451,16 +451,16 @@ int main(void) // for(int i = 0; i < ggml_nelements(wino_res); i++) { // for(int i = 0; i < 26*38; i++) { - // // for(int i = 0; i < conv2d_data.size(); i++) { + // for(int i = 0; i < conv2d_data.size(); i++) { // // float diff = fabs(conv2d_data[i] - wino_data[i]); // float diff = fabs(im2col_data[i] - wino_data[i]); // float diff1 = fabs(im2col_data[i] - conv2d_data[i]); - // // if(diff > 1.e-4) { + // if(diff > 0.5) { // printf("(%7.3f, %7.3f, %7.3f, %.2f, %.2f, %d) \n", // im2col_data[i], conv2d_data[i], // wino_data[i], diff, diff1, i); // // break; - // // } + // } // } ggml_free(model.ctx); From 6d1228803784568901c64ac4218c37be6948ab6f Mon Sep 
17 00:00:00 2001 From: bssrdf Date: Mon, 27 Oct 2025 17:32:03 -0400 Subject: [PATCH 038/122] WIP: fixed a bug in cpy transpos index computation --- ggml/src/ggml-cuda/cpy.cu | 15 ++++++++------- ggml/src/ggml.c | 3 --- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 67901d915a..660c021a43 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -52,7 +52,7 @@ static __global__ void cpy_flt_transpose(const char * cx, char * cdst_direct, co const int64_t nmat = ne /(ne00 * ne01); const int64_t n = ne00 * ne01; // const int64_t n = ne01 * ne02; - int width = gridDim.x * TILE_DIM; + int width = ne01; int x = blockIdx.x * TILE_DIM + threadIdx.x; int y = blockIdx.y * TILE_DIM + threadIdx.y; int tx = blockIdx.y * TILE_DIM + threadIdx.x; // transpose block offset @@ -194,8 +194,8 @@ static void ggml_cpy_flt_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { - const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; - if constexpr ((std::is_same_v && std::is_same_v || + + if constexpr ((std::is_same_v && std::is_same_v || std::is_same_v && std::is_same_v) && transpose){ // printf("cuda cpy transpose ne=%d ne00=%d ne01=%d ne10=%d ne11=%d\n", ne, ne00, ne01, ne10, ne11); @@ -203,12 +203,13 @@ static void ggml_cpy_flt_cuda( // if (ne00 == ne11 && ne01 == ne10 && nb00 == nb11 && nb10 == nb01){ //transpose // if (transpose) { //transpose // printf("cuda cpy transpose ne=%d ne00=%d ne01=%d ne10=%d ne11=%d\n", ne, ne00, ne01, ne10, ne11); - dim3 dimGrid( (ne00 + TILE_DIM - 1) / TILE_DIM, - (ne01 + TILE_DIM - 1) / TILE_DIM, + dim3 dimGrid( (ne01 + TILE_DIM - 1) / TILE_DIM, + (ne00 + TILE_DIM - 1) / TILE_DIM, (ne/(ne00*ne01) + BLOCK_NM - 1) / BLOCK_NM ); dim3 dimBlock(TILE_DIM, BLOCK_ROWS, 1); cpy_flt_transpose<<>>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } else{ // other + const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; cpy_flt><<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } @@ -401,7 +402,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { - if(src1->op_params[10] == 999){ + if(src0->op_params[10] == 999){ ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else { ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); @@ -436,7 +437,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, 
dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { - if(src1->op_params[10] == 999){ + if(src0->op_params[10] == 999){ ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else { ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 2257847f0f..0e172e7216 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3301,9 +3301,6 @@ static struct ggml_tensor * ggml_cont_impl( result->op = GGML_OP_CONT; result->src[0] = a; - if (a->op == GGML_OP_TRANSPOSE) { - result->op_params[10] = a->op_params[10]; // preserve the original order - } return result; } From 3ea524e9c4bf19514d08eb490e6746255ae11a39 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Mon, 27 Oct 2025 23:10:19 -0400 Subject: [PATCH 039/122] WIP: almost working --- ggml/src/ggml-cuda/cpy.cu | 85 +++++--- tests/CMakeLists.txt | 1 + tests/test-backend-ops.cpp | 3 +- tests/test-conv2d-implicit.cpp | 22 +- tests/test-transpose.cpp | 375 +++++++++++++++++++++++++++++++++ 5 files changed, 450 insertions(+), 36 deletions(-) create mode 100644 tests/test-transpose.cpp diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 660c021a43..4405f9e378 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -49,10 +49,11 @@ static __global__ void cpy_flt_transpose(const char * cx, char * cdst_direct, co const T* src = reinterpret_cast(cx); T* dst = reinterpret_cast(cdst); - const int64_t nmat = ne /(ne00 * ne01); + const int64_t nmat = ne / (ne00 * ne01); const int64_t n = ne00 * ne01; // const int64_t n = ne01 * ne02; int width = ne01; + int height = ne00; int x = blockIdx.x * TILE_DIM + threadIdx.x; int y = blockIdx.y * TILE_DIM + threadIdx.y; int tx = blockIdx.y * TILE_DIM + threadIdx.x; // transpose block offset @@ -62,29 +63,65 @@ static __global__ void cpy_flt_transpose(const char * cx, char * cdst_direct, co __shared__ T tile[TILE_DIM][TILE_DIM]; for(int i = 0; i < BLOCK_NM; ++i){ - const unsigned int imat = blockIdx.z * BLOCK_NM + i; - if(imat < nmat){ - for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS){ - const unsigned int idx = (y+j)*width + x; - if(idx < n){ - const int row = threadIdx.y+j; - const int col = threadIdx.x ^ row; - // tile[threadIdx.y+j][threadIdx.x] = src[imat*n + idx]; - tile[row][col] = src[imat*n + idx]; - } - } - __syncthreads(); + __syncthreads(); - for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS){ - const unsigned int idx = (ty+j)*width + tx; - if(idx < n){ - // const int row = threadIdx.x; - const int col = (threadIdx.y+j) ^ threadIdx.x; - // dst[imat*n + idx] = tile[threadIdx.x][threadIdx.y + j]; - dst[imat*n + idx] = tile[threadIdx.x][col]; - } + const unsigned int imat = blockIdx.z * BLOCK_NM + i; + if(imat >= nmat) + break; + for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS){ + if(imat < nmat && x < width && y + j < height){ + const unsigned int idx = (y+j)*width + x; + const int row = threadIdx.y+j; + const int col = threadIdx.x ^ row; + // tile[threadIdx.y+j][threadIdx.x] = src[imat*n + idx]; + tile[row][col] = src[imat*n + idx]; } } + __syncthreads(); + + + // if(threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0){ + // printf("BEGIN %d\n", i); + // for(int jj = 0; jj < TILE_DIM; ++jj){ + // for(int 
ii = 0; ii < TILE_DIM; ++ii) + // printf("%.f, ", tile[jj][ii]); + // printf("]\n"); + // } + // } + + for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS){ + + if(imat < nmat && ty + j < width && tx < height){ + const unsigned int idx = (ty+j)*height + tx; + // const int row = threadIdx.x; + const int col = (threadIdx.y+j) ^ threadIdx.x; + // dst[imat*n + idx] = tile[threadIdx.x][threadIdx.y + j]; + dst[imat*n + idx] = tile[threadIdx.x][col]; + // if(imat*n + idx == 4*ne00){ + // printf("DEBUG: (%u, %u, %u, %u, %u), j=%d, tx=%d, ty=%d, imat=%u idx=%u dst[%u]=%.2f, %f\n", + // threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y, blockIdx.z, j, tx, ty, + // imat, idx, imat*n + idx, dst[imat*n + idx], tile[threadIdx.x][threadIdx.y + j]); + // } + } + } + // } + } + + if(threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0){ + // for(int j = 0; j < 32; ++j){ + // j = 0; + for(int i = 0; i < 32; ++i) + // printf("%.2f, ", src[j*48+i]); + // printf("%.2f, ", src[j*48+i]); + printf("%.2f, ", __half2float(src[i])); + printf("]\n"); + // } + printf("==============================\n"); + // for(int j = 0; j < 32; ++j){ + for(int i = 0; i < 32; ++i) + printf("%.2f, ", __half2float(dst[i])); + printf("]\n"); + // } } } @@ -195,11 +232,11 @@ static void ggml_cpy_flt_cuda( const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { - if constexpr ((std::is_same_v && std::is_same_v || + if constexpr ((std::is_same_v && std::is_same_v || std::is_same_v && std::is_same_v) && transpose){ - // printf("cuda cpy transpose ne=%d ne00=%d ne01=%d ne10=%d ne11=%d\n", ne, ne00, ne01, ne10, ne11); - // printf("cuda cpy transpose nb00=%d nb01=%d nb10=%d nb11=%d\n", nb00, nb01, nb10, nb11); + printf("cuda cpy transpose ne=%d ne00=%d ne01=%d ne10=%d ne11=%d\n", ne, ne00, ne01, ne10, ne11); + printf("cuda cpy transpose nb00=%d nb01=%d nb10=%d nb11=%d\n", nb00, nb01, nb10, nb11); // if (ne00 == ne11 && ne01 == ne10 && nb00 == nb11 && nb10 == nb01){ //transpose // if (transpose) { //transpose // printf("cuda cpy transpose ne=%d ne00=%d ne01=%d ne10=%d ne11=%d\n", ne, ne00, ne01, ne10, ne11); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 7ce76f0105..1787e53eb5 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -199,6 +199,7 @@ endif() llama_build_and_test(test-gguf.cpp) llama_build_and_test(test-backend-ops.cpp) llama_build_and_test(test-conv2d-implicit.cpp) +llama_build_and_test(test-transpose.cpp) llama_build_and_test(test-model-load-cancel.cpp LABEL "model") llama_build_and_test(test-autorelease.cpp LABEL "model") diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 6c0b5d17a6..2016c3f74c 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2458,7 +2458,7 @@ struct test_cpy : public test_case { ggml_tensor * out = ggml_cpy(ctx, src, dst); if(is_transpose) - dst->op_params[10] = 999; + src->op_params[10] = 999; ggml_set_name(out, "out"); return out; @@ -6136,6 +6136,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_I32, {256, 2, 3, 4}, {1, 0, 2, 3})); test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_F32, {256, 2, 3, 4})); test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_F32, {256, 2, 3, 4}, {1, 
0, 2, 3})); + test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {48, 48, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}, true)); test_cases.emplace_back(new test_cont()); test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 1, 1 ,1})); diff --git a/tests/test-conv2d-implicit.cpp b/tests/test-conv2d-implicit.cpp index f3f4f91700..2790b3c235 100644 --- a/tests/test-conv2d-implicit.cpp +++ b/tests/test-conv2d-implicit.cpp @@ -451,17 +451,17 @@ int main(void) // for(int i = 0; i < ggml_nelements(wino_res); i++) { // for(int i = 0; i < 26*38; i++) { - // for(int i = 0; i < conv2d_data.size(); i++) { - // // float diff = fabs(conv2d_data[i] - wino_data[i]); - // float diff = fabs(im2col_data[i] - wino_data[i]); - // float diff1 = fabs(im2col_data[i] - conv2d_data[i]); - // if(diff > 0.5) { - // printf("(%7.3f, %7.3f, %7.3f, %.2f, %.2f, %d) \n", - // im2col_data[i], conv2d_data[i], - // wino_data[i], diff, diff1, i); - // // break; - // } - // } + for(int i = 0; i < conv2d_data.size(); i++) { + // float diff = fabs(conv2d_data[i] - wino_data[i]); + float diff = fabs(im2col_data[i] - wino_data[i]); + float diff1 = fabs(im2col_data[i] - conv2d_data[i]); + if(diff > 0.5) { + printf("(%7.3f, %7.3f, %7.3f, %.2f, %.2f, %d) \n", + im2col_data[i], conv2d_data[i], + wino_data[i], diff, diff1, i); + // break; + } + } ggml_free(model.ctx); ggml_backend_buffer_free(model.buffer); diff --git a/tests/test-transpose.cpp b/tests/test-transpose.cpp new file mode 100644 index 0000000000..73263f3438 --- /dev/null +++ b/tests/test-transpose.cpp @@ -0,0 +1,375 @@ +#include "ggml.h" +#include "ggml-alloc.h" +#include "ggml-cpu.h" +#include "ggml-backend.h" + +#ifdef GGML_USE_CUDA +#include "ggml-cuda.h" +//#include +#endif + +#ifdef GGML_USE_METAL +#include "ggml-metal.h" +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +static void ggml_log_callback_default(ggml_log_level level, const char * text, void * user_data) { + (void) level; + (void) user_data; + fputs(text, stderr); + fflush(stderr); +} + +struct test_model { + struct ggml_tensor * a; + struct ggml_tensor * b; + ggml_backend_t backend = NULL; + ggml_backend_buffer_t buffer; + struct ggml_context * ctx; +}; + + + +void load_model(test_model & model, int ic, int oc, int iw, int ih, int kw = 3, int kh = 3, bool use_gpu = false ) { + // create data + int KW = kw, KH = kh, IC = ic, OC = oc; + int IW = iw, IH = ih, N = 1; + srand(time(NULL)); + + // printf(" input: IC = %d, OC = %d, IW = %d, IH = %d \n ", IC, OC, IW, IH); + + // Initialize adata + std::vector adata(KW * KH * IC * OC); + for (int i = 0; i < KW * KH * IC * OC; i++) { + // adata[i] = 2.f; + adata[i] = (float)i; + // adata[i] = (rand() % 255) / 255.0; + // float r = -1.f + static_cast (rand()) /( static_cast (RAND_MAX/(1.f-(-1.f)))); + // adata[i] = r; + } + + // Convert adata to fp16 format + std::vector hadata(KW * KH * IC * OC); + ggml_fp32_to_fp16_row(adata.data(), hadata.data(), KW * KH * IC * OC); + + // Initialize bdata + std::vector bdata(IW * IH * IC * N); + for (int i = 0; i < IW * IH * IC * N; i++) { + // bdata[i] = (float)(i%IW)/10.f; + // bdata[i] = 1.5f; + bdata[i] = (float)(i+1); + // bdata[i] = (rand() % 255) / 255.0; + // float r = -1.f + static_cast (rand()) /( static_cast (RAND_MAX/(1.f-(-1.f)))); + // bdata[i] = r; + } + + // for(int i = 0; i < IH; i++) { + // // float diff = fabs(conv2d_data[i] - wino_data[i]); + // for(int j = 0; j < IW; j++) { + // printf("%.0f, ", bdata[i*IW+j]); + // } + // printf("\n"); + // } + for(int i = 
0; i < KH; i++) { + // float diff = fabs(conv2d_data[i] - wino_data[i]); + for(int j = 0; j < KW; j++) { + printf("%.0f, ", adata[i*KW+j]); + } + printf("\n"); + } + printf(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n"); + + size_t buffer_size = 0; + { + // buffer_size += KW * KH * IC * OC * ggml_type_size(GGML_TYPE_F32); // tensor a + buffer_size += KW * KH * IC * OC * ggml_type_size(GGML_TYPE_F16); // tensor a + buffer_size += IW * IH * IC * N * ggml_type_size(GGML_TYPE_F32); // tensor b + buffer_size += 1024; // overhead + } + + // printf("%s: ggml tensor size = %d bytes\n", __func__, (int) sizeof(ggml_tensor)); + // printf("%s: backend buffer size = %0.2f MB\n", __func__, (buffer_size/ 1024.f/ 1024.f)); + + int num_tensors = 2; + struct ggml_init_params params { + /*.mem_size =*/ ggml_tensor_overhead() * num_tensors, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + // initialize the backend +#ifdef GGML_USE_CUDA + if (use_gpu) { + // fprintf(stderr, "%s: using CUDA backend\n", __func__); + model.backend = ggml_backend_cuda_init(0); + if (!model.backend) { + fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); + } + } +#endif + +#ifdef GGML_USE_METAL + if (use_gpu) { + fprintf(stderr, "%s: using Metal backend\n", __func__); + ggml_backend_metal_log_set_callback(ggml_log_callback_default, nullptr); + model.backend = ggml_backend_metal_init(); + if (!model.backend) { + fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__); + } + } +#endif + + if(!model.backend) { + // fallback to CPU backend + model.backend = ggml_backend_cpu_init(); + } + + model.buffer = ggml_backend_alloc_buffer(model.backend, buffer_size); + + // create context + model.ctx = ggml_init(params); + + // create tensors + model.a = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F16, KW, KH, IC, OC); + // model.a = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F32, KW, KH, IC, OC); + model.b = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F32, IW, IH, IC, N); + + int64_t *ne = model.a->ne; + printf("before trans: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); + + // create a allocator + struct ggml_tallocr alloc = ggml_tallocr_new(model.buffer); + + // alloc memory + ggml_tallocr_alloc(&alloc, model.a); + + // load data to buffer + if(ggml_backend_is_cpu(model.backend)) { + memcpy(model.a->data, hadata.data(), ggml_nbytes(model.a)); + // memcpy(model.a->data, adata.data(), ggml_nbytes(model.a)); + } else { + ggml_backend_tensor_set(model.a, hadata.data(), 0, ggml_nbytes(model.a)); + // ggml_backend_tensor_set(model.a, adata.data(), 0, ggml_nbytes(model.a)); + } + + // alloc memory + ggml_tallocr_alloc(&alloc, model.b); + + if(ggml_backend_is_cpu(model.backend) +#ifdef GGML_USE_METAL + || ggml_backend_is_metal(model.backend) +#endif + ) { + memcpy(model.b->data, bdata.data(), ggml_nbytes(model.b)); + } else { + ggml_backend_tensor_set(model.b, bdata.data(), 0, ggml_nbytes(model.b)); + } +} + +typedef struct ggml_cgraph* (*build_graph_t)(const test_model& model); + +struct ggml_cgraph * build_graph_0(const test_model& model) { + static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(); + static std::vector buf(buf_size); + + struct ggml_init_params params0 = { + /*.mem_size =*/ buf_size, + /*.mem_buffer =*/ buf.data(), + /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph() + }; + + // create a temporally context to build the graph + struct ggml_context * ctx0 = ggml_init(params0); + + struct ggml_cgraph 
* gf = ggml_new_graph(ctx0); + + int s0 = 1; + int s1 = 1; + int p0 = 1; + int p1 = 1; + int d0 = 1; + int d1 = 1; + + + + // recalculate for avoid fragmentation + // struct ggml_tensor* conv2d_res = ggml_cont(ctx0, ggml_transpose(ctx0, model.b)); + struct ggml_tensor* conv2d_res = ggml_cont(ctx0, ggml_transpose(ctx0, model.a)); + ggml_set_name(conv2d_res, "transpose_res"); + ggml_build_forward_expand(gf, conv2d_res); + int64_t *ne = conv2d_res->ne; + printf("conv2d: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); + + + // struct ggml_tensor* wino_res = ggml_conv_2d_3x3(ctx0, model.a, model.b); + // ggml_set_name(wino_res, "wino_res"); + // ggml_build_forward_expand(gf, wino_res); + // ne = wino_res->ne; + // printf("wino: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); + ggml_free(ctx0); + return gf; +} + + + +std::vector compute_graph(const test_model & model, ggml_gallocr_t allocr, + build_graph_t build_graph, int iters, double *t) { + struct ggml_cgraph * gf = build_graph(model); + + + // allocate tensors + ggml_gallocr_alloc_graph(allocr, gf); + int n_threads = 1; + + if (ggml_backend_is_cpu(model.backend)) { + ggml_backend_cpu_set_n_threads(model.backend, n_threads); + } + +#ifdef GGML_USE_METAL + if (ggml_backend_is_metal(model.backend)) { + ggml_backend_metal_set_n_cb(model.backend, n_threads); + } +#endif + + ggml_backend_synchronize(model.backend); + + ggml_backend_graph_compute(model.backend, gf); + + ggml_backend_synchronize(model.backend); + + int64_t start_time = ggml_time_us(); + + for(int iter=0; iter data(ggml_nelements(res)); + std::vector fdata(ggml_nelements(res)); + std::vector data(ggml_nelements(res)); + ggml_backend_tensor_get(res, fdata.data(), 0, ggml_nbytes(res)); + ggml_fp16_to_fp32_row(fdata.data(), data.data(), ggml_nelements(res)); + *t = time_us/1000; + return data; + +} + + +int main(void) +{ + ggml_time_init(); + std::vector> configs = { + // std::make_tuple(64,64,48,64,3,3), + // std::make_tuple(320,320,104,152,3,3), + // std::make_tuple(640,640,52,76,3,3), + // std::make_tuple(640,640,104,152,3,3), + // std::make_tuple(960,320,104,152,3,3), + // std::make_tuple(1,128,38,49,3,3), + std::make_tuple(1,1,38,49,38,49), + // std::make_tuple(1280,1280,26,38,1,1), + // std::make_tuple(256,128,768,1024,3,3), + // std::make_tuple(256,128,768,1024,1,1), + // std::make_tuple(1280,640,52,76,3,3), + // std::make_tuple(1920,1280,26,38,3,3), + // std::make_tuple(2560,1280,26,38,3,3), + // std::make_tuple(512,512,104,152,3,3), + // std::make_tuple(512,512,208,304,3,3), + // std::make_tuple(512,256,416,608,3,3), + // std::make_tuple(256,128,832,1216,3,3), + // std::make_tuple(256,256,832,1216,3,3), + // std::make_tuple(320,256,1024,1920) + }; + + int k = 0; + + for (auto c : configs){ + test_model model; + load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c), + std::get<3>(c), std::get<4>(c), std::get<5>(c), true); + + ggml_gallocr_t allocr = NULL; + allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend)); + + //create the worst case graph for memory usage estimation + struct ggml_cgraph * gf = build_graph_0(model); + + // compute the required memory + ggml_gallocr_reserve(allocr, gf); + size_t mem_size0 = ggml_gallocr_get_buffer_size(allocr, 0); + // fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f); + + + struct ggml_cgraph * gf_res_0 = NULL; + int iterations = 0; + + double run_time0; + std::vector im2col_data = compute_graph(model, allocr, build_graph_0, iterations, &run_time0); + 
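        // --- illustrative aside, not part of this patch --------------------------
        // A simple CPU reference for checking a row-major transpose, e.g. against
        // the KW x KH slices of model.a initialized in load_model:
        //
        //   static void transpose_ref(const float * src, float * dst, int rows, int cols) {
        //       for (int r = 0; r < rows; ++r) {
        //           for (int c = 0; c < cols; ++c) {
        //               dst[c * rows + r] = src[r * cols + c];
        //           }
        //       }
        //   }
        //
        // Comparing im2col_data (the transposed fp16 result converted back to fp32
        // in compute_graph) against such a reference would turn the printout below
        // into an automatic pass/fail check.
        // --------------------------------------------------------------------------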
+ + + + + + + //create the worst case graph for memory usage estimation + + + + + + + + // for(int i = 0; i < ggml_nelements(wino_res); i++) { + // for(int i = 0; i < 26*38; i++) { + // for(int i = 0; i < std::get<2>(c); i++) { + // // float diff = fabs(conv2d_data[i] - wino_data[i]); + // for(int j = 0; j < std::get<3>(c); j++) { + // printf("%4.1f, ", im2col_data[i*std::get<3>(c)+j]); + // } + // printf("\n"); + // } + for(int i = 0; i < std::get<4>(c); i++) { + // float diff = fabs(conv2d_data[i] - wino_data[i]); + for(int j = 0; j < std::get<5>(c); j++) { + printf("%4.1f, ", im2col_data[i*std::get<5>(c)+j]); + } + printf("\n"); + } + + ggml_free(model.ctx); + ggml_backend_buffer_free(model.buffer); + ggml_backend_free(model.backend); + ggml_gallocr_free(allocr); + + } + + // printf("\nPerforming test:\n"); + return 0; +} From 75dde410a82f3b7d4686880effaf214c24bd1c78 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Tue, 28 Oct 2025 14:41:48 -0400 Subject: [PATCH 040/122] WIP: minor tweak --- ggml/src/ggml-cuda/cpy.cu | 48 +++++++++++++++++--------------------- tests/test-backend-ops.cpp | 9 +++++-- 2 files changed, 29 insertions(+), 28 deletions(-) diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 4405f9e378..514657537f 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -51,7 +51,6 @@ static __global__ void cpy_flt_transpose(const char * cx, char * cdst_direct, co const int64_t nmat = ne / (ne00 * ne01); const int64_t n = ne00 * ne01; - // const int64_t n = ne01 * ne02; int width = ne01; int height = ne00; int x = blockIdx.x * TILE_DIM + threadIdx.x; @@ -59,17 +58,16 @@ static __global__ void cpy_flt_transpose(const char * cx, char * cdst_direct, co int tx = blockIdx.y * TILE_DIM + threadIdx.x; // transpose block offset int ty = blockIdx.x * TILE_DIM + threadIdx.y; - // __shared__ T tile[TILE_DIM * TILE_DIM]; __shared__ T tile[TILE_DIM][TILE_DIM]; for(int i = 0; i < BLOCK_NM; ++i){ - __syncthreads(); const unsigned int imat = blockIdx.z * BLOCK_NM + i; if(imat >= nmat) break; for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS){ - if(imat < nmat && x < width && y + j < height){ + // if(imat < nmat && x < width && y + j < height){ + if(x < width && y + j < height){ const unsigned int idx = (y+j)*width + x; const int row = threadIdx.y+j; const int col = threadIdx.x ^ row; @@ -90,10 +88,9 @@ static __global__ void cpy_flt_transpose(const char * cx, char * cdst_direct, co // } for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS){ - - if(imat < nmat && ty + j < width && tx < height){ + // if(imat < nmat && ty + j < width && tx < height){ + if(ty + j < width && tx < height){ const unsigned int idx = (ty+j)*height + tx; - // const int row = threadIdx.x; const int col = (threadIdx.y+j) ^ threadIdx.x; // dst[imat*n + idx] = tile[threadIdx.x][threadIdx.y + j]; dst[imat*n + idx] = tile[threadIdx.x][col]; @@ -104,25 +101,24 @@ static __global__ void cpy_flt_transpose(const char * cx, char * cdst_direct, co // } } } - // } } - if(threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0){ - // for(int j = 0; j < 32; ++j){ - // j = 0; - for(int i = 0; i < 32; ++i) - // printf("%.2f, ", src[j*48+i]); - // printf("%.2f, ", src[j*48+i]); - printf("%.2f, ", __half2float(src[i])); - printf("]\n"); - // } - printf("==============================\n"); - // for(int j = 0; j < 32; ++j){ - for(int i = 0; i < 32; ++i) - printf("%.2f, ", __half2float(dst[i])); - printf("]\n"); - // } - } + // if(threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 
0 && blockIdx.y == 0 && blockIdx.z == 0){ + // // for(int j = 0; j < 32; ++j){ + // // j = 0; + // for(int i = 0; i < 32; ++i) + // // printf("%.2f, ", src[j*48+i]); + // // printf("%.2f, ", src[j*48+i]); + // printf("%.2f, ", __half2float(src[i])); + // printf("]\n"); + // // } + // printf("==============================\n"); + // // for(int j = 0; j < 32; ++j){ + // for(int i = 0; i < 32; ++i) + // printf("%.2f, ", __half2float(dst[i])); + // printf("]\n"); + // // } + // } } static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) { @@ -235,8 +231,8 @@ static void ggml_cpy_flt_cuda( if constexpr ((std::is_same_v && std::is_same_v || std::is_same_v && std::is_same_v) && transpose){ - printf("cuda cpy transpose ne=%d ne00=%d ne01=%d ne10=%d ne11=%d\n", ne, ne00, ne01, ne10, ne11); - printf("cuda cpy transpose nb00=%d nb01=%d nb10=%d nb11=%d\n", nb00, nb01, nb10, nb11); + // printf("cuda cpy transpose ne=%d ne00=%d ne01=%d ne10=%d ne11=%d\n", ne, ne00, ne01, ne10, ne11); + // printf("cuda cpy transpose nb00=%d nb01=%d nb10=%d nb11=%d\n", nb00, nb01, nb10, nb11); // if (ne00 == ne11 && ne01 == ne10 && nb00 == nb11 && nb10 == nb01){ //transpose // if (transpose) { //transpose // printf("cuda cpy transpose ne=%d ne00=%d ne01=%d ne10=%d ne11=%d\n", ne, ne00, ne01, ne10, ne11); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 2016c3f74c..e564485894 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6840,8 +6840,13 @@ static std::vector> make_test_cases_perf() { // test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F16, {512, 3072, 1, 1})); // test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3})); // test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3})); - test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {1, 0, 2, 3}, true)); - test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {1, 0, 2, 3}, false)); + test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}, true)); + test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}, false)); + test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}, true)); + test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}, false)); + test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}, true)); + test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}, false)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 256, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); From 4b1920e9e786b93f844e92b2146ff86b463e2d94 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 29 Oct 2025 10:40:52 -0400 Subject: [PATCH 041/122] reduced bank conflicts for output --- ggml/src/ggml-cuda/conv2d-implicit.cu | 8 ++++++-- tests/test-conv2d-implicit.cpp | 25 +++++++++++++------------ 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu 
b/ggml/src/ggml-cuda/conv2d-implicit.cu index cb5d4359dd..d793f0de6b 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -1219,8 +1219,9 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, { // output sts uint32_t (®_)[2] = reinterpret_cast(acc_register_[mma_m][mma_n]); - const uint idx = output_sts_addr + + uint idx = output_sts_addr + mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N; + idx = idx ^ ((idx & 0b1110000000) >> 4); uint32_t* dst_ptr = reinterpret_cast(&smemoutput[idx]); dst_ptr[0] = reg_[0]; dst_ptr = reinterpret_cast(&smemoutput[idx + 8 * BN / 2]); @@ -1255,7 +1256,10 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, // if (n < param.n && (m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow) // param.interm[outOffset] = smemoutput[output_lds_addr + subk * 32]; const uint outOffset = n * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col; - output[outOffset] = smemoutput[output_lds_addr + subk + j*32*BN/2]; + uint idx = output_lds_addr + subk + j*32*BN/2; + idx = idx ^ ((idx & 0b1110000000) >> 4); + // output[outOffset] = smemoutput[output_lds_addr + subk + j*32*BN/2]; + output[outOffset] = smemoutput[idx]; // if(outOffset == 32){ // printf("(%u, %u, %u, %u), output[%d,%d,%d]=%f \n", threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y, // n, row, col, __half2float(output[outOffset])); diff --git a/tests/test-conv2d-implicit.cpp b/tests/test-conv2d-implicit.cpp index 2790b3c235..5ad8f2d274 100644 --- a/tests/test-conv2d-implicit.cpp +++ b/tests/test-conv2d-implicit.cpp @@ -357,6 +357,7 @@ int main(void) // std::make_tuple(1280,1280,26,38,1,1), // std::make_tuple(256,128,768,1024,3,3), // std::make_tuple(256,128,768,1024,1,1), + // std::make_tuple(512,256,384,512,1,1), // std::make_tuple(1280,640,52,76,3,3), // std::make_tuple(1920,1280,26,38,3,3), // std::make_tuple(2560,1280,26,38,3,3), @@ -388,7 +389,7 @@ int main(void) struct ggml_cgraph * gf_res_0 = NULL; - int iterations = 20; + int iterations = 0; double run_time0; std::vector im2col_data = compute_graph(model, allocr, build_graph_0, iterations, &run_time0); @@ -451,17 +452,17 @@ int main(void) // for(int i = 0; i < ggml_nelements(wino_res); i++) { // for(int i = 0; i < 26*38; i++) { - for(int i = 0; i < conv2d_data.size(); i++) { - // float diff = fabs(conv2d_data[i] - wino_data[i]); - float diff = fabs(im2col_data[i] - wino_data[i]); - float diff1 = fabs(im2col_data[i] - conv2d_data[i]); - if(diff > 0.5) { - printf("(%7.3f, %7.3f, %7.3f, %.2f, %.2f, %d) \n", - im2col_data[i], conv2d_data[i], - wino_data[i], diff, diff1, i); - // break; - } - } + // for(int i = 0; i < conv2d_data.size(); i++) { + // // float diff = fabs(conv2d_data[i] - wino_data[i]); + // float diff = fabs(im2col_data[i] - wino_data[i]); + // float diff1 = fabs(im2col_data[i] - conv2d_data[i]); + // // if(diff > 0.5) { + // printf("(%7.3f, %7.3f, %7.3f, %.2f, %.2f, %d) \n", + // im2col_data[i], conv2d_data[i], + // wino_data[i], diff, diff1, i); + // // break; + // // } + // } ggml_free(model.ctx); ggml_backend_buffer_free(model.buffer); From 1e568252b568fd9b220b2255c348136deca28631 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 29 Oct 2025 12:11:26 -0400 Subject: [PATCH 042/122] switch to default conv2d interface --- ggml/include/ggml.h | 4 ++-- ggml/src/ggml.c | 11 +++++------ tests/test-backend-ops.cpp | 2 +- tests/test-conv2d-implicit.cpp | 8 ++++---- 4 files changed, 12 insertions(+), 13 
deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 3999acbd4e..26d6f3332c 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1992,8 +1992,8 @@ extern "C" { int p0, // padding dimension 0 int p1, // padding dimension 1 int d0, // dilation dimension 0 - int d1, - int layout); // dilation dimension 1 + int d1); + // int layout); // for future GGML_API struct ggml_tensor * ggml_conv_3d_direct( diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 0e172e7216..bfe772697e 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -4585,9 +4585,9 @@ struct ggml_tensor * ggml_conv_2d_implicitgemm( int p0, // padding dimension 0 int p1, // padding dimension 1 int d0, // dilation dimension 0 - int d1, + int d1){ // 0: NHWC, 1:NCHW - int layout) {// dilation dimension 1 + // int layout) { GGML_ASSERT(a->ne[2] == b->ne[2]); //GGML_ASSERT(a->type == b->type); @@ -4606,12 +4606,10 @@ struct ggml_tensor * ggml_conv_2d_implicitgemm( ggml_set_op_params_i32(result, 3, p1); ggml_set_op_params_i32(result, 4, d0); ggml_set_op_params_i32(result, 5, d1); - ggml_set_op_params_i32(result, 6, layout); struct ggml_tensor *ap, *bp; - if(layout == 0){ - // ap = ggml_cont(ctx, ggml_permute(ctx, a, 1, 2, 0, 3)); - // bp = ggml_cont(ctx, ggml_permute(ctx, b, 1, 2, 0, 3)); + if(a->type == GGML_TYPE_F16 && (a->ne[0] > 1 || a->ne[1] > 1)){ + ggml_set_op_params_i32(result, 6, 0); ap = ggml_reshape_4d(ctx, ggml_cont(ctx, ggml_transpose(ctx, @@ -4623,6 +4621,7 @@ struct ggml_tensor * ggml_conv_2d_implicitgemm( ggml_reshape_3d(ctx, b, b->ne[0]*b->ne[1], b->ne[2], b->ne[3]))), b->ne[2], b->ne[0], b->ne[1], b->ne[3]); } else{ + ggml_set_op_params_i32(result, 6, 1); ap = a; bp = b; } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index e564485894..e5f80fe909 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4273,7 +4273,7 @@ struct test_conv_2d_implicit : public test_case { // } ggml_tensor * out = - ggml_conv_2d_implicitgemm(ctx, kernel, input, stride0, stride1, padding0, padding1, dilation0, dilation1, cwhn?0:1); + ggml_conv_2d_implicitgemm(ctx, kernel, input, stride0, stride1, padding0, padding1, dilation0, dilation1); ggml_set_name(out, "out"); return out; } diff --git a/tests/test-conv2d-implicit.cpp b/tests/test-conv2d-implicit.cpp index 5ad8f2d274..98b8b0e449 100644 --- a/tests/test-conv2d-implicit.cpp +++ b/tests/test-conv2d-implicit.cpp @@ -262,7 +262,6 @@ struct ggml_cgraph * build_graph_2(const test_model& model) { int d1 = 1; - // recalculate for avoid fragmentation // struct ggml_tensor* conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); // ggml_set_name(conv2d_res, "conv2d_res"); @@ -271,7 +270,7 @@ struct ggml_cgraph * build_graph_2(const test_model& model) { // printf("conv2d: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); - struct ggml_tensor* wino_res = ggml_conv_2d_implicitgemm(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1, 0); + struct ggml_tensor* wino_res = ggml_conv_2d_implicitgemm(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); // struct ggml_tensor* wino_res = ggml_conv_2d_direct(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); ggml_set_name(wino_res, "wino_res"); ggml_build_forward_expand(gf, wino_res); @@ -353,9 +352,10 @@ int main(void) // std::make_tuple(640,640,52,76,3,3), // std::make_tuple(640,640,104,152,3,3), // std::make_tuple(960,320,104,152,3,3), - std::make_tuple(1280,1280,26,38,3,3), + // std::make_tuple(1280,1280,26,38,3,3), // std::make_tuple(1280,1280,26,38,1,1), // 
std::make_tuple(256,128,768,1024,3,3), + std::make_tuple(128,3,768,1024,3,3), // std::make_tuple(256,128,768,1024,1,1), // std::make_tuple(512,256,384,512,1,1), // std::make_tuple(1280,640,52,76,3,3), @@ -389,7 +389,7 @@ int main(void) struct ggml_cgraph * gf_res_0 = NULL; - int iterations = 0; + int iterations = 20; double run_time0; std::vector im2col_data = compute_graph(model, allocr, build_graph_0, iterations, &run_time0); From 2dfbbee73f1a3d50b799ab5e8e1c869d0255f16e Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 29 Oct 2025 13:19:35 -0400 Subject: [PATCH 043/122] clean up --- ggml/src/ggml-cuda/conv2d-implicit.cu | 529 ++----------------------- ggml/src/ggml-cuda/conv2d-implicit.cuh | 105 +---- tests/test-backend-ops.cpp | 27 -- 3 files changed, 55 insertions(+), 606 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index d793f0de6b..6b7efbe789 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -8,6 +8,8 @@ typedef unsigned int uint; constexpr uint WARPSIZE = 32; + +//currently not use; in future for split-k kernels static __global__ void reduce_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols, const int nrows) { const int row = blockIdx.x; const int col = threadIdx.x; @@ -31,11 +33,8 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, float * __restrict__ output, const param_t param) { - // __shared__ char smem[4 * (TM*TN*NUM_THREADS <= (BM * BK + BK * (BN+PAD)) ? (BM * BK + BK * (BN+PAD)) : (TM*TN*NUM_THREADS))]; __shared__ char smem[sizeof(float) * (TM*TN*NUM_THREADS) <= sizeof(float) * 2 * (BM+PAD) * BK + sizeof(T)*2*BK * (BN+PAD) ? - sizeof(float)*2*(BM+PAD)*BK + sizeof(T)*2*BK*(BN+PAD) : sizeof(float) * (TM*TN*NUM_THREADS)]; - // __shared__ float smeminput[2 * BM * BK]; - // __shared__ float smemweight[2 * BK * (BN+PAD)]; + sizeof(float)*2*(BM+PAD)*BK + sizeof(T)*2*BK*(BN+PAD) : sizeof(float) * (TM*TN*NUM_THREADS)]; T *smemweight = reinterpret_cast(smem); float *smeminput = reinterpret_cast(smem + 2 * BK * (BN+PAD) * sizeof(T)); @@ -48,12 +47,8 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, // Warp tile const uint lane_id = tx % WARPSIZE; const uint warp_id = tx / WARPSIZE; - const int mma_tid_x = warp_id / (BN / WN); //(lane_id / 2) % 8; - const int mma_tid_y = warp_id % (BN / WN); //(lane_id / 16) * 2 + (lane_id % 2); - - // lds addr - // int weight_lds_addr = (warp_id / 2) * 32 + mma_tid_y * 4; - // int input_lds_addr = (warp_id % 2) * 64 + mma_tid_x * 4; + const int mma_tid_x = warp_id / (BN / WN); + const int mma_tid_y = warp_id % (BN / WN); // size of the warp subtile constexpr uint WMITER = (WM * WN) / (WARPSIZE * TM * TN * WNITER); @@ -61,75 +56,34 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, constexpr uint WSUBN = WN / WNITER; // 32/2=16 // Placement of the thread in the warp subtile - // const uint threadIdxInWarp = tx % WARPSIZE; // [0, 31] const uint threadColInWarp = lane_id % (WSUBN / TN); // i%(16/4) const uint threadRowInWarp = lane_id / (WSUBN / TN); // i/4 - // int x = bx * BM + input_lds_addr; - // int y = by * BN + weight_lds_addr; int z = blockIdx.z; - - // float weight_ldg_reg[4]; - // float input_ldg_reg[4]; - // 当前线程处理的数据点在oh、ow上的坐标 - // int posh_ori = ((bx * 128 + tx / 2 ) / param.Ow) * param.u - param.p; - // int posw_ori = ((bx * 128 + tx / 2 ) % param.Ow) * param.v - param.q; - // int posh_ori = fastdiv(bx * BM + tx / 
2, param.OW_fastdiv) * param.u - param.p; - // int posw_ori = fastmodulo(bx * BM + tx / 2, param.OW_fastdiv) * param.v - param.q; - - - // int inOffset = (ksplit > 0): z * param.c * param.h * param.w ; - // int weiOffset = (by * BN + tx / 8 * 4) * param.c * param.r * param.s; int inChannelOffset = layout == 0 ? param.c * param.w : param.h * param.w; - // int weightChannelOffset = param.r * param.s; int weightKOffset = param.c * param.r * param.s; - // uint ks, start_k; - - // if constexpr (ksplit > 0){ - // const uint ks = (weightKOffset + ksplit - 1) / ksplit; - // const uint start_k = z * ks; - // } else { - // const uint ks = weightKOffset; - // const uint start_k = 0; - // } const uint ks = (ksplit > 0) ? (weightKOffset + ksplit - 1) / ksplit : weightKOffset; const uint start_k = (ksplit > 0)? z * ks: 0; const uint end_k = min(start_k + ks, weightKOffset); - // sts addr - // int weight_sts_addr = (tx % 8) * 132 + - // (tx / 8) * 4; int write_flag = 1; T weight_frag[2][WNITER * TN]; float input_frag[2][WMITER * TM] = {0.f}; float output_frag[WMITER * TM * WNITER * TN] = {0.f}; -// #pragma unroll -// for (int i = 0; i < 8; ++i) -// { -// #pragma unroll -// for (int j = 0; j < 8; ++j) -// { -// output_frag[i][j] = 0; -// } -// } // calculating the indices that this thread will load into SMEM // we'll load 128bit / 32bit = 4 elements per thread at each step const uint innerRowA = tx / (BK / 4); const uint innerColA = tx % (BK / 4); constexpr uint rowStrideA = (NUM_THREADS * 4) / BK; - // const uint innerRowB = tx / (BN / 4); - // const uint innerColB = tx % (BN / 4); - // constexpr uint rowStrideB = NUM_THREADS / (BN / 4); // ldg const uint weight_sts_addr = innerRowA + innerColA * (BN+PAD) * 4; #pragma unroll for (uint offset = 0; offset + rowStrideA <= BN; offset += rowStrideA) { if(vec_load){ - // if (by * BN + innerRowA + offset < param.k && start_k + innerColA * 4 < param.c * param.r * param.s){ if (by * BN + innerRowA + offset < param.k && start_k + innerColA * 4 < end_k){ if constexpr (std::is_same_v){ float4 tmp = reinterpret_cast(&kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4])[0]; @@ -138,26 +92,23 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, smemweight[weight_sts_addr + offset + 2*(BN+PAD)] = tmp.z; smemweight[weight_sts_addr + offset + 3*(BN+PAD)] = tmp.w; }else{ // read 4 halves - // half val[4]; float2 tmp = reinterpret_cast(&kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4])[0]; const half *val = reinterpret_cast(&tmp); - // val[1] = reinterpret_cast(&tmp.y); smemweight[weight_sts_addr + offset + 0] = val[0]; smemweight[weight_sts_addr + offset + (BN+PAD)] = val[1]; smemweight[weight_sts_addr + offset + 2*(BN+PAD)] = val[2]; smemweight[weight_sts_addr + offset + 3*(BN+PAD)] = val[3]; } } else { - #pragma unroll +#pragma unroll for (int i = 0; i < 4; ++i){ smemweight[weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f; } } }else{ - #pragma unroll +#pragma unroll for (int i = 0; i < 4; ++i){ if (by * BN + innerRowA + offset < param.k && start_k + innerColA * 4 + i < end_k){ - // float4 tmp = reinterpret_cast(¶m.weight[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4])[0]; smemweight[weight_sts_addr + offset + i*(BN+PAD)] = kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4 + i]; } else { smemweight[weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f; @@ -167,14 +118,6 @@ static __global__ void conv2d_implicit_kernel(const float * 
__restrict__ input, } - // int curC = (tx / 32) / (param.r * param.s); // channel offset - // int curR = ((tx / 32) % (param.r * param.s)) / param.s; // kernel r offset - // int curS = ((tx / 32) % (param.r * param.s)) % param.s; // kernel s offset - - // int curR = (tx % 2) * 4 / (param.s * param.c); // channel offset - // int curS = ((tx % 2) * 4 % (param.s * param.c)) / param.c; // kernel r offset - // int curC = ((tx % 2) * 4 % (param.s * param.c)) % param.c; // kernel s offset - const uint input_sts_addr = innerRowA + innerColA * (BM+PAD) * 4; #pragma unroll for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) { @@ -184,9 +127,6 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q; int inOffset = n * param.c * param.h * param.w ; if(vec_load){ - // const uint curR = fastdiv(start_k + innerColA * 4, param.SC_fastdiv); // channel offset - // const uint curS = fastdiv(fastmodulo(start_k + innerColA * 4, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - // const uint curC = fastmodulo(fastmodulo(start_k + innerColA * 4, param.SC_fastdiv), param.C_fastdiv); // kernel r offset const uint cur0 = fastdiv(start_k + innerColA * 4, layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset const uint cur1 = fastdiv(fastmodulo(start_k + innerColA * 4, @@ -201,7 +141,6 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, const int curH = posh_ori + curR * param.d_h; // input h const int curW = posw_ori + curS * param.d_w; // input w if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 < end_k){ - // int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; int inOffsetTmp = layout == 0 ? curH * inChannelOffset + curW * param.c + curC: curC * inChannelOffset + curH * param.w + curW; @@ -211,16 +150,13 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, smeminput[input_sts_addr + offset + 2*(BM+PAD)] = tmp.z; smeminput[input_sts_addr + offset + 3*(BM+PAD)] = tmp.w; } else { - #pragma unroll +#pragma unroll for (int i = 0; i < 4; ++i) smeminput[input_sts_addr + offset + i*(BM+PAD)] = 0.f; } } else { - #pragma unroll +#pragma unroll for (int i = 0; i < 4; ++i){ - // const uint curR = fastdiv(start_k + innerColA * 4 + i, param.SC_fastdiv); // channel offset - // const uint curS = fastdiv(fastmodulo(start_k + innerColA * 4 + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - // const uint curC = fastmodulo(fastmodulo(start_k + innerColA * 4 + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset const uint cur0 = fastdiv(start_k + innerColA * 4 + i, layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset const uint cur1 = fastdiv(fastmodulo(start_k + innerColA * 4 + i, @@ -235,7 +171,6 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, const int curH = posh_ori + curR * param.d_h; // input h const int curW = posw_ori + curS * param.d_w; // input w if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 + i < end_k){ - // int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; int inOffsetTmp = layout == 0 ? 
curH * inChannelOffset + curW * param.c + curC: curC * inChannelOffset + curH * param.w + curW; @@ -246,40 +181,9 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, } } } - - // sts - // for (int i = 0; i < 4; ++i) - // { - // smemweight[weight_sts_addr + i*132] = weight_ldg_reg[i]; - // } - // for (int i = 0; i < 4; ++i) - // { - // smeminput[input_sts_addr + i * 128] = input_ldg_reg[i]; - // } - __syncthreads(); - // if(tx == 0 && bx == 0 && by == 0 && z == 0){ - // for(int i=0; i < 128; ++i) - // printf("%.2f,", smeminput[i]); - // printf("\n"); - // for(int i=128; i < 256; ++i) - // printf("%.2f,", smeminput[i]); - // printf("\n"); - // } - - // if(tx == 0 && bx == 0 && by == 0 && z == 0){ - // printf("%u, %u, %u, %u \n", innerRowA, innerColA, rowStrideA, weight_sts_addr); - // for(int i=0; i < 16; ++i) - // printf("%f,", smemweight[i]); - // printf("\n"); - // for(int i=0; i < 16; ++i) - // printf("%f,", param.weight[i*param.c*param.r*param.s]); - // printf("\n"); - // } - // lds - // int input_lds_addr = (warp_id % 2) * 64 + mma_tid_x * 4; const uint input_lds_addr = mma_tid_x * WM; #pragma unroll for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) @@ -288,7 +192,6 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, input_frag[0][wSubRowIdx * TM + i] = smeminput[input_lds_addr + wSubRowIdx * WSUBM + threadRowInWarp * TM + i]; - // int weight_lds_addr = (warp_id / 2) * 32 + mma_tid_y * 4; const uint weight_lds_addr = mma_tid_y * WN; #pragma unroll for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) @@ -297,95 +200,19 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, weight_frag[0][wSubColIdx * TN + i] = smemweight[weight_lds_addr + wSubColIdx * WSUBN + threadColInWarp * TN + i]; -// #pragma unroll -// for (int i = 0; i < 4; ++i) -// { -// weight_frag[0][i] = smemweight[weight_lds_addr + i]; -// weight_frag[0][i + 4] = smemweight[weight_lds_addr + i + 16]; -// } - // if(tx == 0 && bx == 0 && by == 0 && z == 0) - // { - // printf("weight_ldg_reg:%f,%f,%f,%f\n", weight_frag[0][0], weight_frag[0][1], weight_frag[0][2], weight_frag[0][3]); - // printf("weight_ldg_reg:%f,%f,%f,%f\n", weight_frag[0][4], weight_frag[0][5], weight_frag[0][6], weight_frag[0][7]); - // } -// #pragma unroll -// for (int i = 0; i < 4; ++i) -// { -// input_frag[0][i] = smeminput[input_lds_addr + i]; -// input_frag[0][i + 4] = smeminput[input_lds_addr + i + 32]; -// } - - - for (int crs = start_k; crs < end_k; crs += BK) - { - // ldg -// if (by * BN + tx / 2 < param.k && tx % 2 * 4 < param.c * param.r * param.s){ -// float4 tmp = reinterpret_cast(¶m.weight[by * BN + tx / 2 * weightKOffset + tx % 2 * 4 + crs + 8])[0]; -// weight_ldg_reg[0] = tmp.x; -// weight_ldg_reg[1] = tmp.y; -// weight_ldg_reg[2] = tmp.z; -// weight_ldg_reg[3] = tmp.w; -// } else { -// #pragma unroll -// for (int i = 0; i < 4; ++i) -// weight_ldg_reg[i] = 0.0; -// } - // curR = (crs + 8 + tx % 2 * 4) / (param.s * param.c); // channel offset - // curS = ((crs + 8 + tx % 2 * 4) % (param.s * param.c)) / param.c; // kernel r offset - // curC = ((crs + 8 + tx % 2 * 4) % (param.s * param.c)) % param.c; // kernel s offset -// curR = fastdiv(crs + 8 + (tx % 2) * 4, param.SC_fastdiv); // channel offset -// curS = fastdiv(fastmodulo(crs + 8 + (tx % 2) * 4, param.SC_fastdiv), param.C_fastdiv); // kernel r offset -// curC = fastmodulo(fastmodulo(crs + 8 + (tx % 2) * 4, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - -// int curH = 
posh_ori + curR; // input h -// int curW = posw_ori + curS; // input w -// if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h){ -// int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; - -// // float4 tmp = reinterpret_cast(¶m.input[inOffset + inOffsetTmp])[0]; -// // input_ldg_reg[0] = tmp.x; -// // input_ldg_reg[1] = tmp.y; -// // input_ldg_reg[2] = tmp.z; -// // input_ldg_reg[3] = tmp.w; -// reinterpret_cast(&input_ldg_reg[0])[0] = reinterpret_cast(¶m.input[inOffset + inOffsetTmp])[0]; } else { -// #pragma unroll -// for (int i = 0; i < 4; ++i) -// input_ldg_reg[i] = 0.0; -// } + for (int crs = start_k; crs < end_k; crs += BK) { int load_flag = write_flag ^ 1; #pragma unroll for (int subcrs = 0; subcrs < BK - 1; ++subcrs) { -// #pragma unroll -// for (int i = 0; i < 4; ++i) -// { -// weight_frag[(subcrs + 1) % 2][i] = smemweight[load_flag * (BN+4) * 8 + weight_lds_addr + (subcrs + 1) * (BN+4) + i]; -// weight_frag[(subcrs + 1) % 2][i + 4] = smemweight[load_flag * (BN+4) * 8 + weight_lds_addr + (subcrs + 1) * (BN+4) + i + 16]; -// } + #pragma unroll for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) #pragma unroll for (uint i = 0; i < TN; ++i) weight_frag[(subcrs + 1) % 2][wSubColIdx * TN + i] = smemweight[load_flag * (BN+PAD) * BK + (subcrs + 1) * (BN+PAD) + weight_lds_addr + wSubColIdx * WSUBN + threadColInWarp * TN + i]; - // float* base_ptr = smemweight + load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132; - - // // first 4 values -> weight_frag[...][0..3] - // float4 v0 = *reinterpret_cast(base_ptr); - - // // next 4 values (offset +16) -> weight_frag[...][4..7] - // float4 v1 = *reinterpret_cast(base_ptr + 16); - - // // unpack into weight_frag - // *reinterpret_cast(&weight_frag[(subcrs + 1) % 2][0]) = v0; - // *reinterpret_cast(&weight_frag[(subcrs + 1) % 2][4]) = v1; -// #pragma unroll -// for (int i = 0; i < 4; ++i) -// { -// input_frag[(subcrs + 1) % 2][i] = smeminput[load_flag * BM * 8 + input_lds_addr + (subcrs + 1) * BM + i]; -// input_frag[(subcrs + 1) % 2][i + 4] = smeminput[load_flag * BM * 8 + input_lds_addr + (subcrs + 1) * BM + i + 32]; -// } #pragma unroll for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) #pragma unroll @@ -393,15 +220,6 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, input_frag[(subcrs + 1) % 2][wSubRowIdx * TM + i] = smeminput[load_flag * (BM+PAD) * BK + (subcrs + 1) * (BM+PAD) + input_lds_addr + wSubRowIdx * WSUBM + threadRowInWarp * TM + i]; -// #pragma unroll -// for (int i = 0; i < 8; ++i) -// { -// #pragma unroll -// for (int j = 0; j < 8; ++j) -// { -// output_frag[i][j] += weight_frag[subcrs % 2][i] * input_frag[subcrs % 2][j]; -// } -// } // execute warptile matmul #pragma unroll for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) { @@ -416,15 +234,6 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, (wSubColIdx * TN) + resIdxN] += input_frag[subcrs % 2][wSubRowIdx * TM + resIdxM] * ggml_cuda_cast(weight_frag[subcrs % 2][wSubColIdx * TN + resIdxN]); - // if(tx == 0 && bx == 0 && by == 0 && z == 0){ - // printf("subcrs:%d, i:%d, j:%d, %f * %f = %f, acc = %f\n", subcrs, wSubRowIdx * TM + resIdxM, wSubColIdx * TN + resIdxN, - // input_frag[subcrs % 2][wSubRowIdx * TM + resIdxM], - // weight_frag[subcrs % 2][wSubColIdx * TN + resIdxN], - // input_frag[subcrs % 2][wSubRowIdx * TM + resIdxM] * - // weight_frag[subcrs % 2][wSubColIdx * TN + resIdxN], - // output_frag[(wSubRowIdx * TM + resIdxM) * (WNITER * TN) + - // 
(wSubColIdx * TN) + resIdxN]); - // } } } } @@ -450,12 +259,12 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 3*(BN+PAD)] = val[3]; } } else { - #pragma unroll +#pragma unroll for (int i = 0; i < 4; ++i) smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f; } }else{ - #pragma unroll +#pragma unroll for (int i = 0; i < 4; ++i){ if (by * BN + innerRowA + offset < param.k && innerColA * 4 + crs + BK + i < end_k){ // float4 tmp = reinterpret_cast(¶m.weight[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK + i])[0]; @@ -474,9 +283,6 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q; int inOffset = n * param.c * param.h * param.w ; if(vec_load){ - // const uint curR = fastdiv(innerColA * 4 + crs + BK, param.SC_fastdiv); // channel offset - // const uint curS = fastdiv(fastmodulo(innerColA * 4 + crs + BK, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - // const uint curC = fastmodulo(fastmodulo(innerColA * 4 + crs + BK, param.SC_fastdiv), param.C_fastdiv); // kernel r offset const uint cur0 = fastdiv(innerColA * 4 + crs + BK, layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK, @@ -507,11 +313,8 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + i*(BM+PAD)] = 0.f; } } else { - #pragma unroll +#pragma unroll for (int i = 0; i < 4; ++i){ - // const uint curR = fastdiv(innerColA * 4 + crs + BK + i, param.SC_fastdiv); // channel offset - // const uint curS = fastdiv(fastmodulo(innerColA * 4 + crs + BK + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - // const uint curC = fastmodulo(fastmodulo(innerColA * 4 + crs + BK + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset const uint cur0 = fastdiv(innerColA * 4 + crs + BK + i, layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK + i, @@ -527,7 +330,6 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, const int curH = posh_ori + curR * param.d_h; // input h const int curW = posw_ori + curS * param.d_w; // input w if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK + i < end_k){ - // int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; int inOffsetTmp = layout == 0 ? 
curH * inChannelOffset + curW * param.c + curC: curC * inChannelOffset + curH * param.w + curW; @@ -538,17 +340,10 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, } } } - // sts - // for (int i = 0; i < 4; ++i) - // { - // smemweight[write_flag * (BN+4) * 8 + weight_sts_addr + i * (BN+4)] = weight_ldg_reg[i]; - // } - // for (int i = 0; i < 4; ++i) - // { - // smeminput[write_flag * BM * 8 + input_sts_addr + i * BM] = input_ldg_reg[i]; - // } __syncthreads(); + write_flag ^= 1; + #pragma unroll for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) #pragma unroll @@ -561,18 +356,6 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, for (uint i = 0; i < TN; ++i) weight_frag[0][wSubColIdx * TN + i] = smemweight[(load_flag ^ 1) * (BN+PAD) * BK + weight_lds_addr + wSubColIdx * WSUBN + threadColInWarp * TN + i]; -// #pragma unroll -// for (int i = 0; i < 4; ++i) -// { -// weight_frag[0][i] = smemweight[(load_flag ^ 1) * (BN+4) * 8 + weight_lds_addr + i]; -// weight_frag[0][i + 4] = smemweight[(load_flag ^ 1) * (BN+4) * 8 + weight_lds_addr + i + 16]; -// } -// #pragma unroll -// for (int i = 0; i < 4; ++i) -// { -// input_frag[0][i] = smeminput[(load_flag ^ 1) * BM * 8 + input_lds_addr + i]; -// input_frag[0][i + 4] = smeminput[(load_flag ^ 1) * BM * 8 + input_lds_addr + i + 32]; -// } #pragma unroll for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) { #pragma unroll @@ -590,100 +373,12 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, } } } -// #pragma unroll -// for (int i = 0; i < 8; ++i) -// { -// #pragma unroll -// for (int j = 0; j < 8; ++j) -// { -// output_frag[i][j] += weight_frag[1][i] * input_frag[1][j]; -// } -// } } - // if(tx == 59 && bx == 0 && by == 0 && z == 0){ - // for (int i = 0; i < WMITER * TM * WNITER * TN; ++i){ - // printf("%f,", output_frag[i]); - // if((i+1) % (WNITER * TN) == 0) - // printf("\n"); - // } - // printf("\n"); - // } - // if(tx == 59 && bx == 0 && by == 0 && z == 0){ - // int cnt[3] = {0}; - // float values[3] = {-1.f}; - // for (int i = 0; i < WMITER * TM * WNITER * TN; ++i){ - // for(int j = 0; j < 3; j++){ - // if (output_frag[i] == values[j]){ - // cnt[j]++; - // break; - // } else{ - // if (cnt[j] == 0){ - // values[j] = output_frag[i]; - // cnt[j]++; - // break; - // } - // } - // } - // } - // for(int j = 0; j < 3; j++){ - // if(values[j] != -1.f) - // printf("value: %f, cnt: %d \n", values[j], cnt[j]); - // } - // } - // reuse smem float *smemoutput = reinterpret_cast(smem); - // float *smembias = reinterpret_cast(smem + 16 * 1024); - // bias ldg/sts - // if (tx < BN) - // { - // smembias[tx] = param.bias[by * BN + tx]; - // } - - // constexpr uint OUTMITER = (TM * TN * WNITER * WMITER * NUM_THREADS) / (2 * BK * (BM + BN)) / OUTNITER; - // const uint WMITER_TM_OUTMITER = WMITER * TM / OUTMITER; - // const uint WNITER_TN_OUTNITER = WNITER * TN / OUTNITER; - - - -// // uint32_t bias_lds_addr = warp_id / 2 * 32; - -// #pragma unroll -// for (int i = 0; i < 2; ++i) -// { -// #pragma unroll -// for (int j = 0; j < 2; ++j) -// { -// __syncthreads(); - -// #pragma unroll -// for (int subi = 0; subi < 4; ++subi) -// { -// #pragma unroll -// for (int subj = 0; subj < 4; ++subj) -// { -// // output sts -// smemoutput[output_sts_addr + subi * 8 * 4 + subj] = output_frag[i * 4 + subi][j * 4 + subj]; -// } -// } -// __syncthreads(); - -// #pragma unroll -// for (int subk = 0; subk < 16; ++subk) -// { -// int outOffset = z * param.k * param.Oh * 
param.Ow + (m_idx + i * 16 + subk) * param.Oh * param.Ow + n_idx + j * 32; -// if ((m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow) -// param.output[outOffset] = smemoutput[output_lds_addr + subk * 32]; -// } -// } -// } const uint output_lds_addr = warp_id * WSUBM * WSUBN + lane_id; - // const uint m_idx = by * BN + mma_tid_y * WN + threadColInWarp * WNITER_TN_OUTNITER; - // const uint n_idx = bx * BM + mma_tid_x * WM + threadRowInWarp * WMITER_TM_OUTMITER; - // const uint output_sts_addr = warp_id * WMITER_TM_OUTMITER * WNITER_TN_OUTNITER * WARPSIZE + - // (threadRowInWarp * (WSUBN / TN) + threadColInWarp) * WMITER_TM_OUTMITER * WNITER_TN_OUTNITER; const uint output_sts_addr = mma_tid_x * BN / WN * TM * TN * WARPSIZE + mma_tid_y * TM * TN * WARPSIZE + threadColInWarp * TN * WSUBM + threadRowInWarp * TM; const uint m_idx = by * BN + mma_tid_y * WN; @@ -716,9 +411,6 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, const int n = (ksplit > 0) ? gemm_i / PQ : z; const int col = (ksplit > 0) ? gemm_i % PQ : gemm_i; if (n < param.n && row < param.k && col < param.Oh * param.Ow){ - // int outOffset = z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow + (m_idx + i * 16 + subk) * param.Oh * param.Ow + (n_idx + j * 32); - // if (n < param.n && (m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow) - // param.interm[outOffset] = smemoutput[output_lds_addr + subk * 32]; const uint outOffset = ksplit > 0 ? z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col : @@ -736,8 +428,7 @@ template = GGML_CUDA_CC_TURING static_assert(mma_tiles_per_warp_m == 8, "mma_tiles_per_warp_m must be 4"); static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4"); @@ -748,7 +439,7 @@ __device__ __forceinline__ void ldmatrix_a( swizzled_offset = swizzled_offset ^ ((swizzled_offset & 0b1100000) >> 2); uint32_t src_addr = cvta_to_shared_u32(src + swizzled_offset); constexpr unsigned int smem_stride_ = smem_stride * sizeof(half); // convert stride to bytes - + // 0 asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " @@ -782,7 +473,7 @@ __device__ __forceinline__ void ldmatrix_a( ); src_addr ^= 0b10000; - + // 1 asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " @@ -814,7 +505,7 @@ __device__ __forceinline__ void ldmatrix_a( : "=r"(reg_[6][1][0]), "=r"(reg_[6][1][1]), "=r"(reg_[7][1][0]), "=r"(reg_[7][1][1]) : "r"(src_addr + 96 * smem_stride_) ); - + src_addr ^= 0b110000; // 2 @@ -892,31 +583,19 @@ template = GGML_CUDA_CC_TURING + static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4"); static_assert(mma_tiles_per_warp_n == 8, "mma_tiles_per_warp_n must be 8"); uint32_t (®_) [4][8] = reinterpret_cast(reg); -// const unsigned int logical_offset = ((threadIdx.x % 8) * smem_stride) + (((threadIdx.x % 32) / 8) * 8); -// unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b11100000000) >> 5); -// uint32_t src_addr = cvta_to_shared_u32(src + swizzled_offset); -// constexpr unsigned int smem_stride_ = smem_stride * sizeof(half); // convert stride to bytes unsigned int logical_offset = (threadIdx.x % 32) * smem_stride; unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b10000000) >> 4); swizzled_offset = swizzled_offset ^ ((swizzled_offset & 0b1100000) >> 2); uint32_t src_addr = cvta_to_shared_u32(src + swizzled_offset); constexpr unsigned int smem_stride_ = smem_stride * 
sizeof(half); // convert stride to bytes - -// asm volatile ( -// "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " -// "{%0, %1, %2, %3}, [%4];" -// : "=r"(reg_[0][0]), "=r"(reg_[0][1]), "=r"(reg_[0][2]), "=r"(reg_[0][3]) -// : "r"(src_addr) -// ); - // 0 asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " @@ -927,18 +606,15 @@ __device__ __forceinline__ void ldmatrix_b( asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" : "=r"(reg_[0][4]), "=r"(reg_[0][5]), "=r"(reg_[0][6]), "=r"(reg_[0][7]) - // : "r"(src_addr ^ 0b1000000) : "r"(src_addr + 32 * smem_stride_) ); src_addr ^= 0b10000; asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" : "=r"(reg_[1][0]), "=r"(reg_[1][1]), "=r"(reg_[1][2]), "=r"(reg_[1][3]) @@ -946,19 +622,15 @@ __device__ __forceinline__ void ldmatrix_b( ); asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" : "=r"(reg_[1][4]), "=r"(reg_[1][5]), "=r"(reg_[1][6]), "=r"(reg_[1][7]) - // : "r"(src_addr ^ 0b1000000) : "r"(src_addr + 32 * smem_stride_) ); -// src_addr += 8 * smem_stride_; src_addr ^= 0b110000; asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" : "=r"(reg_[2][0]), "=r"(reg_[2][1]), "=r"(reg_[2][2]), "=r"(reg_[2][3]) @@ -966,18 +638,15 @@ __device__ __forceinline__ void ldmatrix_b( ); asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" : "=r"(reg_[2][4]), "=r"(reg_[2][5]), "=r"(reg_[2][6]), "=r"(reg_[2][7]) - // : "r"(src_addr ^ 0b1000000) : "r"(src_addr + 32 * smem_stride_) ); src_addr ^= 0b10000; asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" : "=r"(reg_[3][0]), "=r"(reg_[3][1]), "=r"(reg_[3][2]), "=r"(reg_[3][3]) @@ -985,11 +654,9 @@ __device__ __forceinline__ void ldmatrix_b( ); asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 " "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" : "=r"(reg_[3][4]), "=r"(reg_[3][5]), "=r"(reg_[3][6]), "=r"(reg_[3][7]) - // : "r"(src_addr ^ 0b1000000) : "r"(src_addr + 32 * smem_stride_) ); #else @@ -1006,37 +673,27 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, half * __restrict__ output, const param_t param) { #if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING - constexpr unsigned int MMA_M = 16; - constexpr unsigned int MMA_N = 8; -// if(threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y ==0) -// printf("conv2d_implicit_kernel launch BM:%d, BN:%d, BK:%d, WM:%d, WN:%d, WK:%d, NUM_THREADS:%d \n", BM, BN, BK, WM, WN, WK, NUM_THREADS); +constexpr unsigned int MMA_M = 16; +constexpr unsigned int MMA_N = 8; + const unsigned int K = param.c * param.r * param.s; -// const uint PQ = param.Oh * param.Ow; const uint inChannelOffset = param.c * param.w; const uint weightKOffset = param.c * param.r * param.s; - // for convenience/readability in index calculations -// const unsigned int A_stride = K; -// const unsigned int B_stride = N; -// const unsigned int CD_stride = N; - - // calculate how many bits of shared memory indices are going to be swizzled, and create masks -// constexpr unsigned int SWIZZLE_BITS_B = int_log2(BN 
/ 8); - // loop bounds, constexpr where possible allows for loop unrolling constexpr unsigned int mma_tiles_per_warp_k = 4; constexpr unsigned int mma_tiles_per_warp_m = WM / MMA_M; constexpr unsigned int mma_tiles_per_warp_n = WN / MMA_N; const unsigned int num_block_tiles_k = (K + (BK-1)) / BK; - + // calculate block/warp indices const unsigned int block_m = blockIdx.y; const unsigned int block_n = blockIdx.x; const unsigned int warp_m = threadIdx.y; const unsigned int warp_n = threadIdx.x / 32; - + // double buffering extern __shared__ half shmem[]; half* A_block_smem = shmem; @@ -1046,7 +703,6 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, // declare register storage // ptx instructions expect uint32_t registers, where each uint32_t is 2 halfs packed together uint32_t acc_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][2]; -// float acc_register_[mma_tiles_per_warp_m][mma_tiles_per_warp_n][4]; uint32_t A_register[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2]; uint32_t B_register[mma_tiles_per_warp_k][mma_tiles_per_warp_n]; @@ -1056,10 +712,8 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, half (&B_register_) [mma_tiles_per_warp_k][mma_tiles_per_warp_n][2] = reinterpret_cast(B_register); // accumulators start at 0 - for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++) - { - for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++) - { + for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++){ + for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++){ acc_register_[mma_m][mma_n][0] = 0; acc_register_[mma_m][mma_n][1] = 0; acc_register_[mma_m][mma_n][2] = 0; @@ -1067,9 +721,6 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, } } - // these register arrays are used to cache values pre-fetched from global memory during the inner loop of the kernel - // the code is nicer if we hard code it for these tile dimensions and number of threads - // since we performing this copy with float4 pointers, for these tile dimensions it works out to be 8 float4s for A and 4 float4s for B static_assert(BM == 256); static_assert(BN == 256); static_assert(BK == 32); @@ -1078,31 +729,19 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, float4 B_gmem_cache_reg[4]; // prefetch the first block tile of A,B into shared memory -// half* A_block_gmem = input + (block_m * BM * A_stride); + const half* A_block_gmem = input; -// const half* B_block_gmem = kernel + (block_n * weightKOffset); const half* B_block_gmem = kernel + block_n * BN * weightKOffset; tileMemcpySwizzleA(A_block_gmem, A_block_smem, inChannelOffset, param); tileMemcpySwizzleB(B_block_gmem, B_block_smem, weightKOffset, param); - // construct const pointers to warp tiles for use inside the inner loop -// if(threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x ==0 && blockIdx.y ==0){ -// for(int i = 0; i < 32; ++i) -// printf("%.2f,", __half2float(A_block_smem[i])); -// printf("\n"); -// } - int offset_direction = 1; - for (unsigned int block_k = 1; block_k <= num_block_tiles_k; block_k++) - { + for (unsigned int block_k = 1; block_k <= num_block_tiles_k; block_k++){ __syncthreads(); - if (block_k != num_block_tiles_k) - { - // half* A_block_gmem = A + (block_m * BM * A_stride) + (block_k * BK); + if (block_k != num_block_tiles_k){ const half* A_block_gmem = input; - // const half* B_block_gmem = kernel + (block_n * weightKOffset); const half* B_block_gmem = kernel + 
(block_n * BN * weightKOffset); tileMemcpyLoadA(A_block_gmem, A_gmem_cache_reg, block_k * BK, inChannelOffset, param); tileMemcpyLoadB(B_block_gmem, B_gmem_cache_reg, block_k * BK, weightKOffset, param); @@ -1114,18 +753,14 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, ldmatrix_b(B_warp_tile, B_register_); // outer product between mma tiles - #pragma unroll - for (unsigned int mma_k = 0; mma_k < mma_tiles_per_warp_k; mma_k++) - { - #pragma unroll - for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++) - { - #pragma unroll - for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++) - { +#pragma unroll + for (unsigned int mma_k = 0; mma_k < mma_tiles_per_warp_k; mma_k++){ +#pragma unroll + for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++){ +#pragma unroll + for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++){ asm volatile ( "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " - // "mma.sync.aligned.m16n8k8.row.row.f16.f16.f16.f16 " "{%0, %1}, " "{%2, %3}, " "{%4}, " @@ -1135,53 +770,9 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, "r"(B_register[mma_k][mma_n]) "r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1]) ); - // asm volatile ( - // "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " - // "{%0, %1, %2, %3}," - // "{%4, %5}," - // "{%6}," - // "{%7, %8, %9, %10};\n" - // : "=f"(acc_register_[mma_m][mma_n][0]), "=f"(acc_register_[mma_m][mma_n][1]), - // "=f"(acc_register_[mma_m][mma_n][2]), "=f"(acc_register_[mma_m][mma_n][3]) - // : "r"(A_register[mma_m][mma_k][0]), "r"(A_register[mma_m][mma_k][1]), - // "r"(B_register[mma_k][mma_n]), - // "f"(acc_register_[mma_m][mma_n][0]), "f"(acc_register_[mma_m][mma_n][1]), - // "f"(acc_register_[mma_m][mma_n][2]), "f"(acc_register_[mma_m][mma_n][3]) - // ); } } - // if(threadIdx.x == 12 && threadIdx.y ==0 && blockIdx.x ==0 && blockIdx.y ==0){ - // printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(acc_register_[0][0][0]), __half2float(acc_register_[0][0][1]), - // __half2float(acc_register_[0][0][2]), __half2float(acc_register_[0][0][3])); - // printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, acc_register_[0][0][0], acc_register_[0][0][1], - // acc_register_[0][0][2], acc_register_[0][0][3]); - // printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(A_register_[0][mma_k][0]), __half2float(A_register_[0][mma_k][1]), - // __half2float(A_register_[0][mma_k][2]), __half2float(A_register_[0][mma_k][3])); - // printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(B_register_[mma_k][0][0]), __half2float(B_register_[mma_k][0][1]), - // __half2float(B_register_[mma_k][0][2]), __half2float(B_register_[mma_k][0][3])); - // printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, acc_register_[1][0][0], acc_register_[1][0][1], - // acc_register_[1][0][2], acc_register_[1][0][3]); - // printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(A_register_[1][mma_k][0]), __half2float(A_register_[1][mma_k][1]), - // __half2float(A_register_[1][mma_k][2]), __half2float(A_register_[1][mma_k][3])); - // printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, acc_register_[3][0][0], acc_register_[3][0][1], - // acc_register_[3][0][2], acc_register_[3][0][3]); - // printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(A_register_[3][mma_k][0]), __half2float(A_register_[3][mma_k][1]), - // __half2float(A_register_[3][mma_k][2]), __half2float(A_register_[3][mma_k][3])); 
- // printf(" %d, %d: %f, %f, \n", block_k, mma_k, __half2float(B_register_[mma_k][0][0]), __half2float(B_register_[mma_k][0][1])); - // } - // if(threadIdx.x < 4 && threadIdx.y ==0 && blockIdx.x ==0 && blockIdx.y ==0){ - // printf("A %d, %d, %d: %f, %f \n", block_k, mma_k, threadIdx.x, __half2float(A_register_[3][mma_k][0]), __half2float(A_register_[3][mma_k][1])); - // printf("B %d, %d, %d: %f, %f \n", block_k, mma_k, threadIdx.x, __half2float(B_register_[mma_k][0][0]), __half2float(B_register_[mma_k][0][1])); - // } } - // if(threadIdx.x == 0 && threadIdx.y ==0 && blockIdx.x ==0 && blockIdx.y ==0){ - // printf(" %d: %f, %f, %f, %f \n", block_k, __half2float(acc_register_[3][0][0]), __half2float(acc_register_[3][0][1]), - // __half2float(acc_register_[3][0][2]), __half2float(acc_register_[3][0][3])); - // printf(" %d: %f, %f, %f, %f \n", block_k, __half2float(A_register_[3][0][0]), __half2float(A_register_[3][0][1]), - // __half2float(A_register_[3][0][2]), __half2float(A_register_[3][0][3])); - // printf(" %d: %f, %f, %f, %f \n", block_k, __half2float(B_register_[3][0][0]), __half2float(B_register_[3][0][1]), - // __half2float(B_register_[3][0][2]), __half2float(B_register_[3][0][3])); - // } if (block_k != num_block_tiles_k) @@ -1196,8 +787,6 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, } } - - // reuse smem half *smemoutput = shmem; const uint lane_id = threadIdx.x % WARPSIZE; @@ -1217,7 +806,6 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, { for (unsigned int mma_n = i * mma_tiles_per_warp_n/2; mma_n < (i+1)*mma_tiles_per_warp_n/2; mma_n++) { - // output sts uint32_t (®_)[2] = reinterpret_cast(acc_register_[mma_m][mma_n]); uint idx = output_sts_addr + mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N; @@ -1229,20 +817,6 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, } } __syncthreads(); - // if(threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x ==0 && blockIdx.y ==0){ - // for(int ii = 0; ii < 128; ++ii) - // printf("%.2f,", __half2float(smemoutput[ii])); - // printf("\n"); - // for(int ii = 128; ii < 256; ++ii) - // printf("%.2f,", __half2float(smemoutput[ii])); - // printf("\n"); - // for(int ii = 0; ii < 128; ++ii) - // printf("%.2f,", __half2float(smemoutput[ii*128])); - // printf("\n"); - // for(int ii = 128; ii < 256; ++ii) - // printf("%.2f,", __half2float(smemoutput[ii*128])); - // printf("\n"); - // } #pragma unroll for (int subk = 0; subk < WN / 2; ++subk){ @@ -1252,23 +826,14 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, const int n = fastdiv(gemm_i, param.OHOW_fastdiv); const int col = fastmodulo(gemm_i, param.OHOW_fastdiv); if(n < param.n && row < param.k && col < param.Oh * param.Ow){ - // int outOffset = z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow + (m_idx + i * 16 + subk) * param.Oh * param.Ow + (n_idx + j * 32); - // if (n < param.n && (m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow) - // param.interm[outOffset] = smemoutput[output_lds_addr + subk * 32]; const uint outOffset = n * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col; uint idx = output_lds_addr + subk + j*32*BN/2; idx = idx ^ ((idx & 0b1110000000) >> 4); - // output[outOffset] = smemoutput[output_lds_addr + subk + j*32*BN/2]; output[outOffset] = smemoutput[idx]; - // if(outOffset == 32){ - // printf("(%u, %u, %u, %u), output[%d,%d,%d]=%f \n", threadIdx.x, threadIdx.y, 
blockIdx.x, blockIdx.y, - // n, row, col, __half2float(output[outOffset])); - // } } } } } - #else GGML_UNUSED(input); GGML_UNUSED(kernel); @@ -1279,7 +844,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, } -#define NUM_VARIANTS 6 +#define NUM_VARIANTS 4 /* conv_shapes[][0]: ne_input=[384,512,256,1],ne_kernel=[3,3,256,256] @@ -1313,12 +878,10 @@ static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, int blockx = ((P.Oh * P.Ow + BM - 1) / BM); // blockx number int blocky = (P.k + BN-1) / BN; // blocky number int blockz = P.n; // blockz number - // int threadx = NUM; // threadx number per block int thready = 1; // thready number per block int threadz = 1; // threadz number per block dim3 thblock(NUM_THREADS, thready, threadz); dim3 grid(blockx, blocky, blockz); - // int smem_size = 24 * 1024; if(P.c % 4 == 0){ if(P.layout == 0) conv2d_implicit_kernel= GGML_CUDA_CC_VOLTA - // printf("tensor core path called\n"); constexpr unsigned int BM_dim = 256; constexpr unsigned int BN_dim = 256; constexpr unsigned int BK_dim = 32; @@ -1378,10 +940,6 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa <<>>(X_H, K_D, Y_H.get(), P); const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow * P.n, st); -// #else -// printf("non tensor path called\n"); -// conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); -// #endif } else{ conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); } @@ -1422,13 +980,6 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * // No cwhn GGML_ASSERT(p[7] == false); - // const int IW = input->ne[0]; // input_w - // const int IH = input->ne[1]; // input_h - // const int OW = dst->ne[0]; // output_w - // const int OH = dst->ne[1]; // output_h - // const int KW = kernel->ne[0]; // kernel_w - // const int KH = kernel->ne[1]; // kernel_h - // const int IC = input->ne[2]; // input_channels const int IW = input->ne[LT == 0 ? 1 : 0]; // input_w const int IH = input->ne[LT == 0 ? 
2 : 1]; // input_h const int OW = dst->ne[0]; // output_w diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index 69942bffac..8ed0109390 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -35,45 +35,8 @@ __device__ __forceinline__ void tileMemcpySwizzleB( half* dst, const unsigned int src_stride, param_t param -) -{ +){ #if __CUDA_ARCH__ >= GGML_CUDA_TURING - // constexpr unsigned int SWIZZLE_MASK = 0b111 << SWIZZLE_BITS; - - // // reinterpret input/output as float4 - // float4* src_float4 = reinterpret_cast(src); - // float4* dst_float4 = reinterpret_cast(dst); - // const unsigned int src_stride_vectorized = src_stride / 8; - - // // # of threads is multiple of # of columns in the tile - // constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; - // static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); - - // // flatten out 2d grid of threads into in order of increasing threadIdx.x - // const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; - - // // assign each thread a row/column in the tile, calculate how many iterations we need - // // to cover the whole tile - // constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; - // constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; - // unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; - // const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; - - // #pragma unroll - // for (unsigned int i = 0; i < NUM_ITERS; i++) - // { - // // apply swizzle to the dst index - // const unsigned int src_index = thread_row * src_stride_vectorized + thread_col; - // unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; - // dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK) >> SWIZZLE_BITS); - // if (thread_col * 8 < param.k && start_k + innerColA * 4 < end_k){ - // float4 tmp = reinterpret_cast(&src[thread_row * src_stride_vectorized + thread_col*8)[0]; - // dst_float4[dst_index] = src_float4[src_index]; - // }else{ // read 4 halves - // dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f); - // } - // thread_row += ROW_STEP; - // } constexpr unsigned int SWIZZLE_MASK_1 = 0b10000; constexpr unsigned int SWIZZLE_BITS_1 = 4; @@ -81,10 +44,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB( constexpr unsigned int SWIZZLE_BITS_2 = 2; constexpr unsigned int TILE_COLS = 32; - // reinterpret input/output as float4 - // float4* src_float4 = reinterpret_cast(src); float4* dst_float4 = reinterpret_cast(dst); - // const unsigned int src_stride_vectorized = src_stride / 8; // # of threads is multiple of # of columns in the tile constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; @@ -98,14 +58,12 @@ __device__ __forceinline__ void tileMemcpySwizzleB( constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; - // TODO: next block_k loop const unsigned int curR = fastdiv(thread_col*8, param.SC_fastdiv); // channel offset const unsigned int curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset const unsigned int curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // #pragma unroll - for (unsigned int i = 0; i < NUM_ITERS; i++) - { + for (unsigned int i = 0; i < NUM_ITERS; i++){ // apply swizzle to the dst index const unsigned int src_index = thread_row * src_stride + thread_col * 8; 
unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; @@ -140,16 +98,14 @@ __device__ __forceinline__ void tileMemcpySwizzleA( ) { #if __CUDA_ARCH__ >= GGML_CUDA_TURING + constexpr unsigned int SWIZZLE_MASK_1 = 0b10000; constexpr unsigned int SWIZZLE_BITS_1 = 4; constexpr unsigned int SWIZZLE_MASK_2 = 0b1100; constexpr unsigned int SWIZZLE_BITS_2 = 2; constexpr unsigned int TILE_COLS = 32; - // reinterpret input/output as float4 - // float4* src_float4 = reinterpret_cast(src); float4* dst_float4 = reinterpret_cast(dst); - // const unsigned int src_stride_vectorized = src_stride / 8; // # of threads is multiple of # of columns in the tile constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; @@ -166,16 +122,13 @@ __device__ __forceinline__ void tileMemcpySwizzleA( #pragma unroll - for (unsigned int i = 0; i < NUM_ITERS; i++) - { - // unsigned int gemm_i = blockDim.y * TILE_ROWS + thread_row; + for (unsigned int i = 0; i < NUM_ITERS; i++){ unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row; unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv); unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv); int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p; int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q; unsigned int inOffset = n * param.c * param.h * param.w; - // TODO: next block_k loop const unsigned int curR = fastdiv(thread_col*8, param.SC_fastdiv); // channel offset const unsigned int curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset const unsigned int curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset @@ -187,7 +140,6 @@ __device__ __forceinline__ void tileMemcpySwizzleA( dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curR < param.r && curS < param.s && curC < param.c){ - // const unsigned int src_index = thread_row * src_stride_vectorized + thread_col; const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; dst_float4[dst_index] = reinterpret_cast(&src[inOffset + inOffsetTmp])[0]; } else{ @@ -201,7 +153,7 @@ __device__ __forceinline__ void tileMemcpySwizzleA( GGML_UNUSED(inChannelOffset); GGML_UNUSED(param); NO_DEVICE_CODE; -#endif +#endif } template= GGML_CUDA_TURING - // reinterpret input/output as float4 - // const float4* src_float4 = reinterpret_cast(src); - // const unsigned int src_stride_vectorized = src_stride / 8; // # of threads is multiple of # of columns in the tile constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); - + // flatten out 2d grid of threads into in order of increasing threadIdx.x const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; @@ -240,19 +188,13 @@ __device__ __forceinline__ void tileMemcpyLoadA( static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); #pragma unroll - for (unsigned int i = 0; i < NUM_ITERS; i++) - { - // const unsigned int src_index = thread_row * src_stride_vectorized + thread_col; - // dst_reg[i] = src_float4[src_index]; - // thread_row += ROW_STEP; - // unsigned int gemm_i = blockDim.y * TILE_ROWS + thread_row; + for (unsigned int i = 0; i < NUM_ITERS; i++){ unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row; unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv); unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv); int posh_ori = fastdiv(npq_res, 
param.OW_fastdiv) * param.u - param.p; int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q; unsigned int inOffset = n * param.c * param.h * param.w; - // TODO: next block_k loop const unsigned int curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // channel offset const unsigned int curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset @@ -260,7 +202,6 @@ __device__ __forceinline__ void tileMemcpyLoadA( int curW = posw_ori + curS * param.d_w; // input w if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curR < param.r && curS < param.s && curC < param.c){ - // const unsigned int src_index = thread_row * src_stride_vectorized + thread_col; const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; dst_reg[i] = reinterpret_cast(&src[inOffset + inOffsetTmp])[0]; } else{ @@ -289,17 +230,13 @@ __device__ __forceinline__ void tileMemcpyLoadB( const unsigned int block_k, const unsigned int src_stride, param_t param -) -{ +){ #if __CUDA_ARCH__ >= GGML_CUDA_TURING - // reinterpret input/output as float4 - // const float4* src_float4 = reinterpret_cast(src); - // const unsigned int src_stride_vectorized = src_stride / 8; // # of threads is multiple of # of columns in the tile constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); - + // flatten out 2d grid of threads into in order of increasing threadIdx.x const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; @@ -318,11 +255,7 @@ __device__ __forceinline__ void tileMemcpyLoadB( const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // #pragma unroll - for (unsigned int i = 0; i < NUM_ITERS; i++) - { - // const unsigned int src_index = thread_row * src_stride_vectorized + thread_col; - // dst_reg[i] = src_float4[src_index]; - // thread_row += ROW_STEP; + for (unsigned int i = 0; i < NUM_ITERS; i++){ const unsigned int src_index = thread_row * src_stride + block_k + thread_col * 8; if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c){ dst_reg[i] = reinterpret_cast(&src[src_index])[0]; @@ -338,7 +271,7 @@ __device__ __forceinline__ void tileMemcpyLoadB( GGML_UNUSED(src_stride); GGML_UNUSED(param); NO_DEVICE_CODE; -#endif +#endif } @@ -354,6 +287,7 @@ __device__ __forceinline__ void tileMemcpySwizzleStore( ) { #if __CUDA_ARCH__ >= GGML_CUDA_TURING + constexpr unsigned int SWIZZLE_MASK_1 = 0b10000; constexpr unsigned int SWIZZLE_BITS_1 = 4; constexpr unsigned int SWIZZLE_MASK_2 = 0b1100; @@ -392,9 +326,9 @@ __device__ __forceinline__ void tileMemcpySwizzleStore( } #else GGML_UNUSED(src_reg); - GGML_UNUSED(dst); + GGML_UNUSED(dst); NO_DEVICE_CODE; -#endif +#endif } __device__ __forceinline__ uint32_t cvta_to_shared_u32(const void *pointer) { @@ -409,15 +343,6 @@ __device__ __forceinline__ uint32_t cvta_to_shared_u32(const void *pointer) { return address; } -// constexpr unsigned int int_log2(unsigned int x) -// { -// unsigned int result = 0; -// while (x >>= 1) -// { -// result++; -// } -// return result; -// } #define CUDA_CONV2D_IMPLICT_BLOCK_SIZE 256 void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index e5f80fe909..b3948a0bbf 100644 --- 
a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6807,33 +6807,6 @@ static std::vector> make_test_cases_perf() { GGML_TYPE_F32, 1, 1, p0, p1, 1, 1, false)); } - for (auto act_case : cases_sd) { - GGML_ASSERT(act_case[idx_sd["kw"]] == 3 || act_case[idx_sd["kw"]] == 1); - GGML_ASSERT(act_case[idx_sd["kh"]] == 3 || act_case[idx_sd["kh"]] == 1); - - uint32_t p0 = act_case[idx_sd["kw"]] == 3 ? 1 : 0; - uint32_t p1 = act_case[idx_sd["kh"]] == 3 ? 1 : 0; - - test_cases.emplace_back(new test_conv_2d_implicit( - { act_case[idx_sd["iw"]], act_case[idx_sd["ih"]], act_case[idx_sd["Cin"]], act_case[idx_sd["B"]] }, - { act_case[idx_sd["kw"]], act_case[idx_sd["kh"]], act_case[idx_sd["Cin"]], act_case[idx_sd["Cout"]] }, - GGML_TYPE_F16, 1, 1, p0, p1, 1, 1, true)); - } - - for (auto act_case : cases_sd) { - GGML_ASSERT(act_case[idx_sd["kw"]] == 3 || act_case[idx_sd["kw"]] == 1); - GGML_ASSERT(act_case[idx_sd["kh"]] == 3 || act_case[idx_sd["kh"]] == 1); - - uint32_t p0 = act_case[idx_sd["kw"]] == 3 ? 1 : 0; - uint32_t p1 = act_case[idx_sd["kh"]] == 3 ? 1 : 0; - - test_cases.emplace_back(new test_conv_2d_implicit( - { act_case[idx_sd["iw"]], act_case[idx_sd["ih"]], act_case[idx_sd["Cin"]], act_case[idx_sd["B"]] }, - { act_case[idx_sd["kw"]], act_case[idx_sd["kh"]], act_case[idx_sd["Cin"]], act_case[idx_sd["Cout"]] }, - GGML_TYPE_F32, 1, 1, p0, p1, 1, 1, true)); - } - - test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1})); test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1})); From 55859a86aa466ff5cd6836fb85a21ebd56dde282 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 29 Oct 2025 21:36:03 -0400 Subject: [PATCH 044/122] remove implicit op and related calls; replace conv_2d with conv_2d_implicit kernel --- ggml/include/ggml.h | 14 --- ggml/src/ggml-cpu/ggml-cpu.c | 6 -- ggml/src/ggml-cuda/conv2d-implicit.cu | 120 ++++++++++++++-------- ggml/src/ggml-cuda/ggml-cuda.cu | 6 +- ggml/src/ggml.c | 66 +----------- tests/test-backend-ops.cpp | 139 +++----------------------- tests/test-conv2d-implicit.cpp | 99 ++++-------------- 7 files changed, 117 insertions(+), 333 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 26d6f3332c..b7b472c56e 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -513,7 +513,6 @@ extern "C" { GGML_OP_IM2COL_BACK, GGML_OP_IM2COL_3D, GGML_OP_CONV_2D, - GGML_OP_CONV_2D_IMPLICIT, GGML_OP_CONV_3D, GGML_OP_CONV_2D_DW, GGML_OP_CONV_TRANSPOSE_2D, @@ -1983,19 +1982,6 @@ extern "C" { int d0, // dilation dimension 0 int d1); // dilation dimension 1 - GGML_API struct ggml_tensor * ggml_conv_2d_implicitgemm( - struct ggml_context * ctx, - struct ggml_tensor * a, // convolution kernel [KW, KH, IC, OC] - struct ggml_tensor * b, // input data [W, H, C, N] - int s0, // stride dimension 0 - int s1, // stride dimension 1 - int p0, // padding dimension 0 - int p1, // padding dimension 1 - int d0, // dilation dimension 0 - int d1); - // int layout); // for future - - GGML_API struct ggml_tensor * ggml_conv_3d_direct( struct ggml_context * ctx, struct ggml_tensor * a, // kernel [KW, KH, KD, IC * OC] diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 6b6efebad5..c131290849 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -1887,10 +1887,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_conv_2d(params, tensor); } break; - case GGML_OP_CONV_2D_IMPLICIT: - { 
- ggml_compute_forward_conv_2d(params, tensor); - } break; case GGML_OP_CONV_3D: { ggml_compute_forward_conv_3d(params, tensor); @@ -2268,7 +2264,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_IM2COL_BACK: case GGML_OP_IM2COL_3D: case GGML_OP_CONV_2D: - case GGML_OP_CONV_2D_IMPLICIT: case GGML_OP_CONV_3D: case GGML_OP_CONV_2D_DW: case GGML_OP_CONV_TRANSPOSE_1D: @@ -2794,7 +2789,6 @@ struct ggml_cplan ggml_graph_plan( } } break; case GGML_OP_CONV_2D: - case GGML_OP_CONV_2D_IMPLICIT: case GGML_OP_CONV_3D: { cur = GGML_IM2COL_WORK_SIZE; diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 6b7efbe789..fcb053c61d 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -7,6 +7,9 @@ typedef unsigned int uint; constexpr uint WARPSIZE = 32; +#define CUDA_NCHW_2_NHWC_TILE_DIM 32 +#define CUDA_NCHW_2_NHWC_BLOCK_NM 8 +#define CUDA_NCHW_2_NHWC_BLOCK_ROWS 8 //currently not use; in future for split-k kernels @@ -23,6 +26,41 @@ static __global__ void reduce_f32(const float * __restrict__ x, float * __restri } } +template +static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, const int ne00, const int ne01){ + + const int64_t nmat = ne / (ne00 * ne01); + const int64_t n = ne00 * ne01; + + int x = blockIdx.x * CUDA_NCHW_2_NHWC_TILE_DIM + threadIdx.x; + int y = blockIdx.y * CUDA_NCHW_2_NHWC_TILE_DIM + threadIdx.y; + int tx = blockIdx.y * CUDA_NCHW_2_NHWC_TILE_DIM + threadIdx.x; // transpose block offset + int ty = blockIdx.x * CUDA_NCHW_2_NHWC_TILE_DIM + threadIdx.y; + + __shared__ src_T tile[CUDA_NCHW_2_NHWC_TILE_DIM][CUDA_NCHW_2_NHWC_TILE_DIM]; + + for(int i = 0; i < CUDA_NCHW_2_NHWC_BLOCK_NM; ++i){ + + const unsigned int imat = blockIdx.z * CUDA_NCHW_2_NHWC_BLOCK_NM + i; + if(imat >= nmat) + break; + for (int j = 0; j < CUDA_NCHW_2_NHWC_TILE_DIM; j += CUDA_NCHW_2_NHWC_BLOCK_ROWS){ + if(x < ne01 && y + j < ne00){ + const int row = threadIdx.y+j; + const int col = threadIdx.x ^ row; + tile[row][col] = src[imat*n + (y+j)*ne01 + x]; + } + } + __syncthreads(); + + for (int j = 0; j < CUDA_NCHW_2_NHWC_TILE_DIM; j += CUDA_NCHW_2_NHWC_BLOCK_ROWS){ + if(ty + j < ne01 && tx < ne00){ + const int col = (threadIdx.y+j) ^ threadIdx.x; + dst[imat*n + (ty+j)*ne00 + tx] = ggml_cuda_cast(tile[threadIdx.x][col]); + } + } + } +} template<<>>(X_D, K_D, Y_D, P); - else if(P.layout == 1) - conv2d_implicit_kernel<<>>(X_D, K_D, Y_D, P); - } else{ - if(P.layout == 0) - conv2d_implicit_kernel<<>>(X_D, K_D, Y_D, P); - else if(P.layout == 1) - conv2d_implicit_kernel<<>>(X_D, K_D, Y_D, P); - } + + conv2d_implicit_kernel<<>>(X_D, K_D, Y_D, P); } static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const float * X_D, const half * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) { - if (GGML_CUDA_CC_IS_NVIDIA(cc) && ampere_mma_available(cc) && P.layout == 0 && P.c % 8 == 0) { + if (GGML_CUDA_CC_IS_NVIDIA(cc) && ampere_mma_available(cc) && P.c % 8 == 0 && (P.r > 1 || P.s > 1)) { + + int id = ggml_cuda_get_device(); + + int64_t ne = P.c * P.h * P.w * P.n; + int64_t ne00 = P.c; + int64_t ne01 = P.h * P.w; + ggml_cuda_pool_alloc input_f16(ctx.pool(id), ne); + + dim3 dimGrid( (ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, + (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, + (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ; + dim3 dimBlock(CUDA_NCHW_2_NHWC_TILE_DIM,CUDA_NCHW_2_NHWC_BLOCK_ROWS, 1); 
+ NCHW2NHWC<<>>(X_D, input_f16.get(), ne, ne00, ne01); + + ne = P.c * P.r * P.s * P.k; + ne01 = P.r * P.s; + ggml_cuda_pool_alloc kernel_f16(ctx.pool(id), ne); + dim3 dimGrid1((ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, + (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, + (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ; + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + + const half *X_H = input_f16.get(); + const half *K_H = kernel_f16.get(); + ggml_cuda_pool_alloc Y_H(ctx.pool(id), P.k * P.Oh * P.Ow * P.n); + constexpr unsigned int BM_dim = 256; constexpr unsigned int BN_dim = 256; constexpr unsigned int BK_dim = 32; @@ -925,19 +977,9 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa dim3 gridDim(BlocksN, BlocksM); dim3 blockDim(ThreadsN, ThreadsM); - int id = ggml_cuda_get_device(); - ggml_cuda_pool_alloc x_f16(ctx.pool(id)); - - const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(GGML_TYPE_F32); - GGML_ASSERT(to_fp16_cuda != nullptr); - size_t ne = P.c * P.h * P.w * P.n; - x_f16.alloc(ne); - to_fp16_cuda(X_D, x_f16.get(), ne, st); - const half *X_H = x_f16.get(); - ggml_cuda_pool_alloc Y_H(ctx.pool(id), P.k * P.Oh * P.Ow * P.n); conv2d_implicit_kernel - <<>>(X_H, K_D, Y_H.get(), P); + <<>>(X_H, K_H, Y_H.get(), P); const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow * P.n, st); } else{ @@ -971,28 +1013,28 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * const int PD_Y = p[3]; // padding_y const int DL_X = p[4]; // dilation_x const int DL_Y = p[5]; // dilation_y - const int LT = p[6]; // layout + // const int LT = p[6]; // layout - GGML_ASSERT(LT == 0 || LT == 1); + // GGML_ASSERT(LT == 0 || LT == 1); // same number of input channels - GGML_ASSERT(LT == 0 ? input->ne[0] == kernel->ne[0] : input->ne[2] == kernel->ne[2]); + // GGML_ASSERT(LT == 0 ? input->ne[0] == kernel->ne[0] : input->ne[2] == kernel->ne[2]); // No cwhn - GGML_ASSERT(p[7] == false); + GGML_ASSERT(p[6] == false); - const int IW = input->ne[LT == 0 ? 1 : 0]; // input_w - const int IH = input->ne[LT == 0 ? 2 : 1]; // input_h + const int IW = input->ne[0]; // input_w + const int IH = input->ne[1]; // input_h const int OW = dst->ne[0]; // output_w const int OH = dst->ne[1]; // output_h - const int KW = kernel->ne[LT == 0 ? 1 : 0]; // kernel_w - const int KH = kernel->ne[LT == 0 ? 2 : 1]; // kernel_h - const int IC = input->ne[LT == 0 ? 
0: 2]; // input_channels + const int KW = kernel->ne[0]; // kernel_w + const int KH = kernel->ne[1]; // kernel_h + const int IC = input->ne[2]; // input_channels const int OC = kernel->ne[3]; // ouptut_chanles const int B = input->ne[3]; // n_batches - + const int64_t total = B * OC * OH * OW; - + param_t params = { B, IC, IH, IW, OC, KH, KW, ST_Y, ST_X, PD_Y, PD_X, DL_Y, DL_X, OH, OW }; params.SC_fastdiv = init_fastdiv_values(KW*IC); params.OW_fastdiv = init_fastdiv_values(OW); @@ -1000,7 +1042,7 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * params.C_fastdiv = init_fastdiv_values(IC); params.RS_fastdiv = init_fastdiv_values(KW*KH); params.S_fastdiv = init_fastdiv_values(KW); - params.layout = LT; + // params.layout = LT; if (kernel->type == GGML_TYPE_F16) { conv2d_implicit_cuda_f16(ctx, X_D, (half *) K_D, Y_D, cc, params, st); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 154076f38d..29fa63777b 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2462,11 +2462,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg ggml_cuda_op_im2col_3d(ctx, dst); break; case GGML_OP_CONV_2D: - ggml_cuda_op_conv2d(ctx, dst); - break; - case GGML_OP_CONV_2D_IMPLICIT: ggml_cuda_op_conv2d_implicit(ctx, dst); - break; + break; case GGML_OP_CONV_2D_DW: ggml_cuda_op_conv2d_dw(ctx, dst); break; @@ -3580,7 +3577,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_IM2COL: case GGML_OP_IM2COL_3D: case GGML_OP_CONV_2D: - case GGML_OP_CONV_2D_IMPLICIT: case GGML_OP_CONV_2D_DW: case GGML_OP_CONV_TRANSPOSE_2D: case GGML_OP_POOL_2D: diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index bfe772697e..03c8dca3e5 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -976,7 +976,6 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "IM2COL_BACK", "IM2COL_3D", "CONV_2D", - "CONV_2D_IMPLICIT", "CONV_3D", "CONV_2D_DW", "CONV_TRANSPOSE_2D", @@ -1020,7 +1019,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GLU", }; -static_assert(GGML_OP_COUNT == 91, "GGML_OP_COUNT != 91"); +static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1081,7 +1080,6 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "im2col_back(x)", "im2col_3d(x)", "conv_2d(x)", - "conv_2d_implicit(x)", "conv_3d(x)", "conv_2d_dw(x)", "conv_transpose_2d(x)", @@ -1125,7 +1123,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "glu(x)", }; -static_assert(GGML_OP_COUNT == 91, "GGML_OP_COUNT != 91"); +static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -4573,66 +4571,6 @@ struct ggml_tensor * ggml_conv_2d_direct( return result; } - -// ggml_conv_2d_implicitgemm - -struct ggml_tensor * ggml_conv_2d_implicitgemm( - struct ggml_context * ctx, - struct ggml_tensor * a, // convolution kernel [KW, KH, IC, OC] - struct ggml_tensor * b, // input data [W, H, C, N] - int s0, // stride dimension 0 - int s1, // stride dimension 1 - int p0, // padding dimension 0 - int p1, // padding dimension 1 - int d0, // dilation dimension 0 - int d1){ - // 0: NHWC, 1:NCHW - // int layout) { - - GGML_ASSERT(a->ne[2] == b->ne[2]); - //GGML_ASSERT(a->type == b->type); - - int64_t ne[4]; - ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0); - ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1); - ne[2] = a->ne[3]; - 
ne[3] = b->ne[3]; - - struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne); - - ggml_set_op_params_i32(result, 0, s0); - ggml_set_op_params_i32(result, 1, s1); - ggml_set_op_params_i32(result, 2, p0); - ggml_set_op_params_i32(result, 3, p1); - ggml_set_op_params_i32(result, 4, d0); - ggml_set_op_params_i32(result, 5, d1); - - struct ggml_tensor *ap, *bp; - if(a->type == GGML_TYPE_F16 && (a->ne[0] > 1 || a->ne[1] > 1)){ - ggml_set_op_params_i32(result, 6, 0); - ap = ggml_reshape_4d(ctx, - ggml_cont(ctx, - ggml_transpose(ctx, - ggml_reshape_3d(ctx, a, a->ne[0]*a->ne[1], a->ne[2], a->ne[3]))), - a->ne[2], a->ne[0], a->ne[1], a->ne[3]); - bp = ggml_reshape_4d(ctx, - ggml_cont(ctx, - ggml_transpose(ctx, - ggml_reshape_3d(ctx, b, b->ne[0]*b->ne[1], b->ne[2], b->ne[3]))), - b->ne[2], b->ne[0], b->ne[1], b->ne[3]); - } else{ - ggml_set_op_params_i32(result, 6, 1); - ap = a; - bp = b; - } - - result->op = GGML_OP_CONV_2D_IMPLICIT; - result->src[0] = ap; - result->src[1] = bp; - - return result; -} - // ggml_conv_3d struct ggml_tensor * ggml_conv_3d_direct( diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index b3948a0bbf..a7aba2b447 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4191,94 +4191,6 @@ struct test_conv_2d : public test_case { } }; -// CONV_2D_IMPLICIT -struct test_conv_2d_implicit : public test_case { - const std::array ne_input; - const std::array ne_kernel; - const ggml_type type_kernel; - const int stride0; - const int stride1; - const int padding0; - const int padding1; - const int dilation0; - const int dilation1; - // Whether the inputs are contiguous in the channel dim or the width dim - const bool cwhn; - - - - std::string vars() override { - return VARS_TO_STR10(ne_input, ne_kernel, type_kernel, stride0, stride1, padding0, padding1, dilation0, dilation1, cwhn); - } - - double max_nmse_err() override { - return 5e-4; - } - - uint64_t op_flops(ggml_tensor * t) override { - GGML_UNUSED(t); - // Just counting matmul costs: - // KxCRS @ CRSxNPQ = KxNPQ --> KxNPQx(CRS+CRS-1) flops - - // Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) - auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t { - return (ins + 2 * p - d * (ks - 1) - 1) / s + 1; - }; - - int64_t W = ne_input[0]; - int64_t H = ne_input[1]; - int64_t KW = ne_kernel[0]; - int64_t KH = ne_kernel[1]; - int64_t Cin = ne_kernel[2]; - int64_t Cout = ne_kernel[3]; - int64_t N = ne_input[3]; - int64_t OH = calc_conv_output_size(H, KH, stride0, padding0, dilation0); - int64_t OW = calc_conv_output_size(W, KW, stride0, padding0, dilation0); - - int64_t K = Cout; - int64_t CRS = Cin * KH * KW; - int64_t NPQ = N * OH * OW; - - return K * NPQ * (2 * CRS - 1); - } - - test_conv_2d_implicit(std::array ne_input = { 64, 64, 16, 1 }, - std::array ne_kernel = { 3, 3, 1, 16 }, ggml_type type_kernel = GGML_TYPE_F32, int stride0 = 1, - int stride1 = 1, int padding0 = 0, int padding1 = 0, int dilation0 = 1, int dilation1 = 1, bool cwhn = false) : - ne_input(ne_input), - ne_kernel(ne_kernel), - type_kernel(type_kernel), - stride0(stride0), - stride1(stride1), - padding0(padding0), - padding1(padding1), - dilation0(dilation0), - dilation1(dilation1), - cwhn(cwhn) {} - - ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data()); - ggml_set_name(input, "input"); - - ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 
4, ne_kernel.data()); - ggml_set_name(kernel, "kernel"); - - // if (cwhn) { - // // change memory layout to channel-most-contiguous (CWHN), - // // then permute it back so NE matches the original input - // input = ggml_cont(ctx, ggml_permute(ctx, input, 1, 2, 0, 3)); - // input = ggml_permute(ctx, input, 2, 0, 1, 3); - // kernel = ggml_cont(ctx, ggml_permute(ctx, kernel, 2, 3, 1, 0)); - // kernel = ggml_permute(ctx, kernel, 3, 2, 0, 1); - // } - - ggml_tensor * out = - ggml_conv_2d_implicitgemm(ctx, kernel, input, stride0, stride1, padding0, padding1, dilation0, dilation1); - ggml_set_name(out, "out"); - return out; - } -}; - // GGML_OP_CONV_2D_DW struct test_conv_2d_dw : public test_case { const std::array ne_input; @@ -5941,30 +5853,6 @@ static std::vector> make_test_cases_eval() { } } - for (uint32_t s0 : { 1, 3 }) { - for (uint32_t p1 : { 2, 5 }) { - for (uint32_t Cin : { 1, 25 }) { - for (uint32_t Cout : { 1, 12 }) { - for (uint32_t KH : { 1, 2, 3, 11 }) { - for (uint32_t KW : { 1, 2, 3, 11 }) { - for (uint32_t H : { 1, 133 }) { - for (uint32_t W : { 1, 141 }) { - if (calc_conv_output_size(W, KW, s0, p0, d0) > 0 && - calc_conv_output_size(H, KH, s1, p1, d1) > 0) { - for (auto kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) { - test_cases.emplace_back(new test_conv_2d_implicit( - { W, H, Cin, 2 }, { KW, KH, Cin, Cout }, kernel_type, s0, s1, p0, p1, d0, d1, false)); - } - } - } - } - } - } - } - } - } - } - // sycl backend will limit task global_range < MAX_INT // test cases for 2D im2col with large input W and H (occurs in stable-diffusion) // however these cases need to alloc more memory which may fail in some devices (Intel Arc770, etc.) @@ -6732,16 +6620,6 @@ static std::vector> make_test_cases_perf() { } } - for (auto kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) { - for (auto act_case : cases) { - // Direct CONV_2D - test_cases.emplace_back(new test_conv_2d_implicit( - { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] }, - { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] }, - kernel_type, 1, 1, 0, 0, 1, 1, false)); - } - } - // Stable-diffusion layers std::map idx_sd{ { "iw", 0 }, @@ -6788,7 +6666,7 @@ static std::vector> make_test_cases_perf() { uint32_t p0 = act_case[idx_sd["kw"]] == 3 ? 1 : 0; uint32_t p1 = act_case[idx_sd["kh"]] == 3 ? 1 : 0; - test_cases.emplace_back(new test_conv_2d_implicit( + test_cases.emplace_back(new test_conv_2d( { act_case[idx_sd["iw"]], act_case[idx_sd["ih"]], act_case[idx_sd["Cin"]], act_case[idx_sd["B"]] }, { act_case[idx_sd["kw"]], act_case[idx_sd["kh"]], act_case[idx_sd["Cin"]], act_case[idx_sd["Cout"]] }, GGML_TYPE_F16, 1, 1, p0, p1, 1, 1, false)); @@ -6801,12 +6679,25 @@ static std::vector> make_test_cases_perf() { uint32_t p0 = act_case[idx_sd["kw"]] == 3 ? 1 : 0; uint32_t p1 = act_case[idx_sd["kh"]] == 3 ? 1 : 0; - test_cases.emplace_back(new test_conv_2d_implicit( + test_cases.emplace_back(new test_conv_2d( { act_case[idx_sd["iw"]], act_case[idx_sd["ih"]], act_case[idx_sd["Cin"]], act_case[idx_sd["B"]] }, { act_case[idx_sd["kw"]], act_case[idx_sd["kh"]], act_case[idx_sd["Cin"]], act_case[idx_sd["Cout"]] }, GGML_TYPE_F32, 1, 1, p0, p1, 1, 1, false)); } + // for (auto act_case : cases_sd) { + // GGML_ASSERT(act_case[idx_sd["kw"]] == 3 || act_case[idx_sd["kw"]] == 1); + // GGML_ASSERT(act_case[idx_sd["kh"]] == 3 || act_case[idx_sd["kh"]] == 1); + + // uint32_t p0 = act_case[idx_sd["kw"]] == 3 ? 1 : 0; + // uint32_t p1 = act_case[idx_sd["kh"]] == 3 ? 
1 : 0; + + // test_cases.emplace_back(new test_conv_2d_implicit( + // { act_case[idx_sd["iw"]], act_case[idx_sd["ih"]], act_case[idx_sd["Cin"]], act_case[idx_sd["B"]] }, + // { act_case[idx_sd["kw"]], act_case[idx_sd["kh"]], act_case[idx_sd["Cin"]], act_case[idx_sd["Cout"]] }, + // GGML_TYPE_F16, 1, 1, p0, p1, 1, 1, false)); + // } + test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1})); test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1})); diff --git a/tests/test-conv2d-implicit.cpp b/tests/test-conv2d-implicit.cpp index 98b8b0e449..7b7a32d9f6 100644 --- a/tests/test-conv2d-implicit.cpp +++ b/tests/test-conv2d-implicit.cpp @@ -239,49 +239,6 @@ struct ggml_cgraph * build_graph_1(const test_model& model) { return gf; } -struct ggml_cgraph * build_graph_2(const test_model& model) { - static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(); - static std::vector buf(buf_size); - - struct ggml_init_params params0 = { - /*.mem_size =*/ buf_size, - /*.mem_buffer =*/ buf.data(), - /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph() - }; - - // create a temporally context to build the graph - struct ggml_context * ctx0 = ggml_init(params0); - - struct ggml_cgraph * gf = ggml_new_graph(ctx0); - - int s0 = 1; - int s1 = 1; - int p0 = 1; - int p1 = 1; - int d0 = 1; - int d1 = 1; - - - // recalculate for avoid fragmentation - // struct ggml_tensor* conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); - // ggml_set_name(conv2d_res, "conv2d_res"); - // ggml_build_forward_expand(gf, conv2d_res); - // int64_t *ne = conv2d_res->ne; - // printf("conv2d: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); - - - struct ggml_tensor* wino_res = ggml_conv_2d_implicitgemm(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); - // struct ggml_tensor* wino_res = ggml_conv_2d_direct(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); - ggml_set_name(wino_res, "wino_res"); - ggml_build_forward_expand(gf, wino_res); - // ne = wino_res->ne; - // printf("wino: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); - ggml_free(ctx0); - return gf; -} - - - std::vector compute_graph(const test_model & model, ggml_gallocr_t allocr, build_graph_t build_graph, int iters, double *t) { @@ -352,10 +309,10 @@ int main(void) // std::make_tuple(640,640,52,76,3,3), // std::make_tuple(640,640,104,152,3,3), // std::make_tuple(960,320,104,152,3,3), - // std::make_tuple(1280,1280,26,38,3,3), + std::make_tuple(1280,1280,26,38,3,3), // std::make_tuple(1280,1280,26,38,1,1), // std::make_tuple(256,128,768,1024,3,3), - std::make_tuple(128,3,768,1024,3,3), + // std::make_tuple(128,3,768,1024,3,3), // std::make_tuple(256,128,768,1024,1,1), // std::make_tuple(512,256,384,512,1,1), // std::make_tuple(1280,640,52,76,3,3), @@ -389,7 +346,7 @@ int main(void) struct ggml_cgraph * gf_res_0 = NULL; - int iterations = 20; + int iterations = 0; double run_time0; std::vector im2col_data = compute_graph(model, allocr, build_graph_0, iterations, &run_time0); @@ -418,51 +375,31 @@ int main(void) ggml_gallocr_free(allocr); - allocr = NULL; - - allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend)); - - //create the worst case graph for memory usage estimation - gf = build_graph_2(model); - - // compute the required memory - ggml_gallocr_reserve(allocr, gf); - size_t mem_size2 = ggml_gallocr_get_buffer_size(allocr, 0); - // fprintf(stderr, "%s: 
compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f); - - - struct ggml_cgraph * gf_res_2 = NULL; - - double run_time2; - std::vector wino_data = compute_graph(model, allocr, build_graph_2, iterations, &run_time2); - - if(k==0) { k = 1; - fprintf(stderr, "| (IC, OC, IW, IH) | im2col+GEMM TIME | im2col+GEMM VRAM | direct TIME | direct VRAM | implicit GEMM TIME | implicit GEMM VRAM \n"); - fprintf(stderr, "| --- | --- | --- | --- | --- | --- | --- \n"); + fprintf(stderr, "| (IC, OC, IW, IH) | im2col+GEMM TIME | im2col+GEMM VRAM | implicit GEMM TIME | implicit GEMM VRAM \n"); + fprintf(stderr, "| --- | --- | --- | --- | --- \n"); } - fprintf(stderr, " | (%d, %d, %d, %d, %d, %d) | %.2f ms | %.2f MB | %.2f ms | %.2f MB | %.2f ms | %.2f MB\n", + fprintf(stderr, " | (%d, %d, %d, %d) | %.2f ms | %.2f MB | %.2f ms | %.2f MB\n", std::get<0>(c), std::get<1>(c), std::get<2>(c), std::get<3>(c), std::get<4>(c), std::get<5>(c), run_time0, mem_size0/1024.0f/1024.0f, - run_time1, mem_size1/1024.0f/1024.0f, - run_time2, mem_size2/1024.0f/1024.0f); + run_time1, mem_size1/1024.0f/1024.0f + ); // for(int i = 0; i < ggml_nelements(wino_res); i++) { // for(int i = 0; i < 26*38; i++) { - // for(int i = 0; i < conv2d_data.size(); i++) { - // // float diff = fabs(conv2d_data[i] - wino_data[i]); - // float diff = fabs(im2col_data[i] - wino_data[i]); - // float diff1 = fabs(im2col_data[i] - conv2d_data[i]); - // // if(diff > 0.5) { - // printf("(%7.3f, %7.3f, %7.3f, %.2f, %.2f, %d) \n", - // im2col_data[i], conv2d_data[i], - // wino_data[i], diff, diff1, i); - // // break; - // // } - // } + for(int i = 0; i < conv2d_data.size(); i++) { + // float diff = fabs(conv2d_data[i] - wino_data[i]); + float diff = fabs(im2col_data[i] - conv2d_data[i]); + // if(diff > 0.5) { + printf("(%7.3f, %7.3f, %.2f, %d) \n", + im2col_data[i], conv2d_data[i], + diff, i); + // break; + // } + } ggml_free(model.ctx); ggml_backend_buffer_free(model.buffer); From a3b4d8d31eec48fe9977b4e5150f85947f9c2871 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 29 Oct 2025 21:46:15 -0400 Subject: [PATCH 045/122] clean up --- ggml/src/ggml-cuda/cpy.cu | 121 +--------- ggml/src/ggml-cuda/cpy.cuh | 5 - ggml/src/ggml.c | 1 - tests/CMakeLists.txt | 2 - tests/test-conv2d-implicit.cpp | 413 --------------------------------- tests/test-transpose.cpp | 375 ------------------------------ 6 files changed, 5 insertions(+), 912 deletions(-) delete mode 100644 tests/test-conv2d-implicit.cpp delete mode 100644 tests/test-transpose.cpp diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 514657537f..c0a568f4ab 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -37,90 +37,6 @@ static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne cpy_1(cx + x_offset, cdst + dst_offset); } - -template -static __global__ void cpy_flt_transpose(const char * cx, char * cdst_direct, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) { - - char * cdst = (cdst_indirect != nullptr) ? 
cdst_indirect[graph_cpynode_index]: cdst_direct; - - const T* src = reinterpret_cast(cx); - T* dst = reinterpret_cast(cdst); - - const int64_t nmat = ne / (ne00 * ne01); - const int64_t n = ne00 * ne01; - int width = ne01; - int height = ne00; - int x = blockIdx.x * TILE_DIM + threadIdx.x; - int y = blockIdx.y * TILE_DIM + threadIdx.y; - int tx = blockIdx.y * TILE_DIM + threadIdx.x; // transpose block offset - int ty = blockIdx.x * TILE_DIM + threadIdx.y; - - __shared__ T tile[TILE_DIM][TILE_DIM]; - - for(int i = 0; i < BLOCK_NM; ++i){ - - const unsigned int imat = blockIdx.z * BLOCK_NM + i; - if(imat >= nmat) - break; - for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS){ - // if(imat < nmat && x < width && y + j < height){ - if(x < width && y + j < height){ - const unsigned int idx = (y+j)*width + x; - const int row = threadIdx.y+j; - const int col = threadIdx.x ^ row; - // tile[threadIdx.y+j][threadIdx.x] = src[imat*n + idx]; - tile[row][col] = src[imat*n + idx]; - } - } - __syncthreads(); - - - // if(threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0){ - // printf("BEGIN %d\n", i); - // for(int jj = 0; jj < TILE_DIM; ++jj){ - // for(int ii = 0; ii < TILE_DIM; ++ii) - // printf("%.f, ", tile[jj][ii]); - // printf("]\n"); - // } - // } - - for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS){ - // if(imat < nmat && ty + j < width && tx < height){ - if(ty + j < width && tx < height){ - const unsigned int idx = (ty+j)*height + tx; - const int col = (threadIdx.y+j) ^ threadIdx.x; - // dst[imat*n + idx] = tile[threadIdx.x][threadIdx.y + j]; - dst[imat*n + idx] = tile[threadIdx.x][col]; - // if(imat*n + idx == 4*ne00){ - // printf("DEBUG: (%u, %u, %u, %u, %u), j=%d, tx=%d, ty=%d, imat=%u idx=%u dst[%u]=%.2f, %f\n", - // threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y, blockIdx.z, j, tx, ty, - // imat, idx, imat*n + idx, dst[imat*n + idx], tile[threadIdx.x][threadIdx.y + j]); - // } - } - } - } - - // if(threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0){ - // // for(int j = 0; j < 32; ++j){ - // // j = 0; - // for(int i = 0; i < 32; ++i) - // // printf("%.2f, ", src[j*48+i]); - // // printf("%.2f, ", src[j*48+i]); - // printf("%.2f, ", __half2float(src[i])); - // printf("]\n"); - // // } - // printf("==============================\n"); - // // for(int j = 0; j < 32; ++j){ - // for(int i = 0; i < 32; ++i) - // printf("%.2f, ", __half2float(dst[i])); - // printf("]\n"); - // // } - // } -} - static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) { float * cdstf = (float *)(cdsti); @@ -228,28 +144,9 @@ static void ggml_cpy_flt_cuda( const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { - if constexpr ((std::is_same_v && std::is_same_v || - std::is_same_v && std::is_same_v) - && transpose){ - // printf("cuda cpy transpose ne=%d ne00=%d ne01=%d ne10=%d ne11=%d\n", ne, ne00, ne01, ne10, ne11); - // printf("cuda cpy transpose nb00=%d nb01=%d nb10=%d nb11=%d\n", nb00, nb01, nb10, nb11); - // if (ne00 == ne11 && ne01 == ne10 && nb00 == nb11 && nb10 == nb01){ //transpose - // if (transpose) { //transpose - // printf("cuda cpy transpose ne=%d ne00=%d ne01=%d ne10=%d ne11=%d\n", ne, ne00, ne01, ne10, ne11); - dim3 dimGrid( (ne01 + TILE_DIM - 1) / TILE_DIM, - (ne00 + 
TILE_DIM - 1) / TILE_DIM, - (ne/(ne00*ne01) + BLOCK_NM - 1) / BLOCK_NM ); - dim3 dimBlock(TILE_DIM, BLOCK_ROWS, 1); - cpy_flt_transpose<<>>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); - } else{ // other - const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; - cpy_flt><<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); - } - // } else{ - // cpy_flt><<>> - // (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); - // } + const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; + cpy_flt><<>> + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } static void ggml_cpy_f32_q8_0_cuda( @@ -435,11 +332,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { - if(src0->op_params[10] == 999){ - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); - } else { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); - } + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) { ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { @@ -470,11 +363,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { - if(src0->op_params[10] == 999){ - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); - } else { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); - } + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) { ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { diff --git a/ggml/src/ggml-cuda/cpy.cuh 
b/ggml/src/ggml-cuda/cpy.cuh index 211348b66a..0bd3c0c6f8 100644 --- a/ggml/src/ggml-cuda/cpy.cuh +++ b/ggml/src/ggml-cuda/cpy.cuh @@ -2,11 +2,6 @@ #define CUDA_CPY_BLOCK_SIZE 64 -const int TILE_DIM = 32; -const int BLOCK_ROWS = 8; -const int BLOCK_NM = 8; - - void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection = false); void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 03c8dca3e5..a792d6b888 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3612,7 +3612,6 @@ struct ggml_tensor * ggml_transpose( result->op = GGML_OP_TRANSPOSE; result->src[0] = a; - result->op_params[10] = 999; // the transpose flag return result; } diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 1787e53eb5..9171957756 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -198,8 +198,6 @@ if (NOT LLAMA_SANITIZE_ADDRESS) endif() llama_build_and_test(test-gguf.cpp) llama_build_and_test(test-backend-ops.cpp) -llama_build_and_test(test-conv2d-implicit.cpp) -llama_build_and_test(test-transpose.cpp) llama_build_and_test(test-model-load-cancel.cpp LABEL "model") llama_build_and_test(test-autorelease.cpp LABEL "model") diff --git a/tests/test-conv2d-implicit.cpp b/tests/test-conv2d-implicit.cpp deleted file mode 100644 index 7b7a32d9f6..0000000000 --- a/tests/test-conv2d-implicit.cpp +++ /dev/null @@ -1,413 +0,0 @@ -#include "ggml.h" -#include "ggml-alloc.h" -#include "ggml-cpu.h" -#include "ggml-backend.h" - -#ifdef GGML_USE_CUDA -#include "ggml-cuda.h" -//#include -#endif - -#ifdef GGML_USE_METAL -#include "ggml-metal.h" -#endif - -#include -#include -#include -#include -#include -#include -#include -#include - -static void ggml_log_callback_default(ggml_log_level level, const char * text, void * user_data) { - (void) level; - (void) user_data; - fputs(text, stderr); - fflush(stderr); -} - -struct test_model { - struct ggml_tensor * a; - struct ggml_tensor * b; - ggml_backend_t backend = NULL; - ggml_backend_buffer_t buffer; - struct ggml_context * ctx; -}; - - - -void load_model(test_model & model, int ic, int oc, int iw, int ih, int kw = 3, int kh = 3, bool use_gpu = false ) { - // create data - int KW = kw, KH = kh, IC = ic, OC = oc; - int IW = iw, IH = ih, N = 1; - srand(time(NULL)); - - // printf(" input: IC = %d, OC = %d, IW = %d, IH = %d \n ", IC, OC, IW, IH); - - // Initialize adata - std::vector adata(KW * KH * IC * OC); - for (int i = 0; i < KW * KH * IC * OC; i++) { - // adata[i] = 2.f; - // adata[i] = (float)(i%KW)-1.f; - // adata[i] = (rand() % 255) / 255.0; - float r = -1.f + static_cast (rand()) /( static_cast (RAND_MAX/(1.f-(-1.f)))); - adata[i] = r; - } - - // Convert adata to fp16 format - std::vector hadata(KW * KH * IC * OC); - ggml_fp32_to_fp16_row(adata.data(), hadata.data(), KW * KH * IC * OC); - - // Initialize bdata - std::vector bdata(IW * IH * IC * N); - for (int i = 0; i < IW * IH * IC * N; i++) { - // bdata[i] = (float)(i%IW)/10.f; - // bdata[i] = 1.5f; - // bdata[i] = (rand() % 255) / 255.0; - float r = -1.f + static_cast (rand()) /( static_cast (RAND_MAX/(1.f-(-1.f)))); - bdata[i] = r; - } - - size_t buffer_size = 0; - { - // buffer_size += KW * KH * IC * OC * ggml_type_size(GGML_TYPE_F32); // tensor a - buffer_size += KW * KH * IC * OC * ggml_type_size(GGML_TYPE_F16); // tensor a - buffer_size += IW * IH * IC * N * ggml_type_size(GGML_TYPE_F32); // tensor b - buffer_size += 1024; // overhead - } - - // printf("%s: ggml tensor 
size = %d bytes\n", __func__, (int) sizeof(ggml_tensor)); - // printf("%s: backend buffer size = %0.2f MB\n", __func__, (buffer_size/ 1024.f/ 1024.f)); - - int num_tensors = 2; - struct ggml_init_params params { - /*.mem_size =*/ ggml_tensor_overhead() * num_tensors, - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ true, - }; - - // initialize the backend -#ifdef GGML_USE_CUDA - if (use_gpu) { - // fprintf(stderr, "%s: using CUDA backend\n", __func__); - model.backend = ggml_backend_cuda_init(0); - if (!model.backend) { - fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); - } - } -#endif - -#ifdef GGML_USE_METAL - if (use_gpu) { - fprintf(stderr, "%s: using Metal backend\n", __func__); - ggml_backend_metal_log_set_callback(ggml_log_callback_default, nullptr); - model.backend = ggml_backend_metal_init(); - if (!model.backend) { - fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__); - } - } -#endif - - if(!model.backend) { - // fallback to CPU backend - model.backend = ggml_backend_cpu_init(); - } - - model.buffer = ggml_backend_alloc_buffer(model.backend, buffer_size); - - // create context - model.ctx = ggml_init(params); - - // create tensors - model.a = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F16, KW, KH, IC, OC); - // model.a = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F32, KW, KH, IC, OC); - model.b = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F32, IW, IH, IC, N); - - // create a allocator - struct ggml_tallocr alloc = ggml_tallocr_new(model.buffer); - - // alloc memory - ggml_tallocr_alloc(&alloc, model.a); - - // load data to buffer - if(ggml_backend_is_cpu(model.backend)) { - memcpy(model.a->data, hadata.data(), ggml_nbytes(model.a)); - // memcpy(model.a->data, adata.data(), ggml_nbytes(model.a)); - } else { - ggml_backend_tensor_set(model.a, hadata.data(), 0, ggml_nbytes(model.a)); - // ggml_backend_tensor_set(model.a, adata.data(), 0, ggml_nbytes(model.a)); - } - - // alloc memory - ggml_tallocr_alloc(&alloc, model.b); - - if(ggml_backend_is_cpu(model.backend) -#ifdef GGML_USE_METAL - || ggml_backend_is_metal(model.backend) -#endif - ) { - memcpy(model.b->data, bdata.data(), ggml_nbytes(model.b)); - } else { - ggml_backend_tensor_set(model.b, bdata.data(), 0, ggml_nbytes(model.b)); - } -} - -typedef struct ggml_cgraph* (*build_graph_t)(const test_model& model); - -struct ggml_cgraph * build_graph_0(const test_model& model) { - static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(); - static std::vector buf(buf_size); - - struct ggml_init_params params0 = { - /*.mem_size =*/ buf_size, - /*.mem_buffer =*/ buf.data(), - /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph() - }; - - // create a temporally context to build the graph - struct ggml_context * ctx0 = ggml_init(params0); - - struct ggml_cgraph * gf = ggml_new_graph(ctx0); - - int s0 = 1; - int s1 = 1; - int p0 = 1; - int p1 = 1; - int d0 = 1; - int d1 = 1; - - - - // recalculate for avoid fragmentation - struct ggml_tensor* conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); - ggml_set_name(conv2d_res, "conv2d_res"); - ggml_build_forward_expand(gf, conv2d_res); - // int64_t *ne = conv2d_res->ne; - // printf("conv2d: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); - - - // struct ggml_tensor* wino_res = ggml_conv_2d_3x3(ctx0, model.a, model.b); - // ggml_set_name(wino_res, "wino_res"); - // ggml_build_forward_expand(gf, wino_res); - // ne = wino_res->ne; - // printf("wino: (%zu, %zu, %zu, %zu) 
\n", ne[0], ne[1], ne[2], ne[3]); - ggml_free(ctx0); - return gf; -} - -struct ggml_cgraph * build_graph_1(const test_model& model) { - static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(); - static std::vector buf(buf_size); - - struct ggml_init_params params0 = { - /*.mem_size =*/ buf_size, - /*.mem_buffer =*/ buf.data(), - /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph() - }; - - // create a temporally context to build the graph - struct ggml_context * ctx0 = ggml_init(params0); - - struct ggml_cgraph * gf = ggml_new_graph(ctx0); - - int s0 = 1; - int s1 = 1; - int p0 = 1; - int p1 = 1; - int d0 = 1; - int d1 = 1; - - - - // recalculate for avoid fragmentation - // struct ggml_tensor* conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); - // ggml_set_name(conv2d_res, "conv2d_res"); - // ggml_build_forward_expand(gf, conv2d_res); - // int64_t *ne = conv2d_res->ne; - // printf("conv2d: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); - - - // struct ggml_tensor* wino_res = ggml_conv_2d_implicitgemm(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); - struct ggml_tensor* wino_res = ggml_conv_2d_direct(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); - ggml_set_name(wino_res, "wino_res"); - ggml_build_forward_expand(gf, wino_res); - // ne = wino_res->ne; - // printf("wino: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); - ggml_free(ctx0); - return gf; -} - - -std::vector compute_graph(const test_model & model, ggml_gallocr_t allocr, - build_graph_t build_graph, int iters, double *t) { - struct ggml_cgraph * gf = build_graph(model); - - - // allocate tensors - ggml_gallocr_alloc_graph(allocr, gf); - int n_threads = 1; - - if (ggml_backend_is_cpu(model.backend)) { - ggml_backend_cpu_set_n_threads(model.backend, n_threads); - } - -#ifdef GGML_USE_METAL - if (ggml_backend_is_metal(model.backend)) { - ggml_backend_metal_set_n_cb(model.backend, n_threads); - } -#endif - - - - ggml_backend_graph_compute(model.backend, gf); - - ggml_backend_synchronize(model.backend); - - int64_t start_time = ggml_time_us(); - - for(int iter=0; iter data(ggml_nelements(res)); - ggml_backend_tensor_get(res, data.data(), 0, ggml_nbytes(res)); - - *t = time_us/1000; - return data; - -} - - -int main(void) -{ - ggml_time_init(); - std::vector> configs = { - // std::make_tuple(64,64,48,64,3,3), - // std::make_tuple(320,320,104,152,3,3), - // std::make_tuple(640,640,52,76,3,3), - // std::make_tuple(640,640,104,152,3,3), - // std::make_tuple(960,320,104,152,3,3), - std::make_tuple(1280,1280,26,38,3,3), - // std::make_tuple(1280,1280,26,38,1,1), - // std::make_tuple(256,128,768,1024,3,3), - // std::make_tuple(128,3,768,1024,3,3), - // std::make_tuple(256,128,768,1024,1,1), - // std::make_tuple(512,256,384,512,1,1), - // std::make_tuple(1280,640,52,76,3,3), - // std::make_tuple(1920,1280,26,38,3,3), - // std::make_tuple(2560,1280,26,38,3,3), - // std::make_tuple(512,512,104,152,3,3), - // std::make_tuple(512,512,208,304,3,3), - // std::make_tuple(512,256,416,608,3,3), - // std::make_tuple(256,128,832,1216,3,3), - // std::make_tuple(256,256,832,1216,3,3), - // std::make_tuple(320,256,1024,1920) - }; - - int k = 0; - - for (auto c : configs){ - test_model model; - load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c), - std::get<3>(c), std::get<4>(c), std::get<5>(c), true); - - ggml_gallocr_t allocr = NULL; - allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend)); - - 
//create the worst case graph for memory usage estimation - struct ggml_cgraph * gf = build_graph_0(model); - - // compute the required memory - ggml_gallocr_reserve(allocr, gf); - size_t mem_size0 = ggml_gallocr_get_buffer_size(allocr, 0); - // fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f); - - - struct ggml_cgraph * gf_res_0 = NULL; - int iterations = 0; - - double run_time0; - std::vector im2col_data = compute_graph(model, allocr, build_graph_0, iterations, &run_time0); - - ggml_gallocr_free(allocr); - - allocr = NULL; - - allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend)); - - //create the worst case graph for memory usage estimation - gf = build_graph_1(model); - - // compute the required memory - ggml_gallocr_reserve(allocr, gf); - size_t mem_size1 = ggml_gallocr_get_buffer_size(allocr, 0); - // fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f); - - - struct ggml_cgraph * gf_res_1 = NULL; - - double run_time1; - // std::vector wino_data = compute_graph(model, allocr, build_graph_1, iterations, &run_time1); - std::vector conv2d_data = compute_graph(model, allocr, build_graph_1, iterations, &run_time1); - - - ggml_gallocr_free(allocr); - - if(k==0) { - k = 1; - fprintf(stderr, "| (IC, OC, IW, IH) | im2col+GEMM TIME | im2col+GEMM VRAM | implicit GEMM TIME | implicit GEMM VRAM \n"); - fprintf(stderr, "| --- | --- | --- | --- | --- \n"); - } - - fprintf(stderr, " | (%d, %d, %d, %d) | %.2f ms | %.2f MB | %.2f ms | %.2f MB\n", - std::get<0>(c), std::get<1>(c), std::get<2>(c), std::get<3>(c), std::get<4>(c), std::get<5>(c), - run_time0, mem_size0/1024.0f/1024.0f, - run_time1, mem_size1/1024.0f/1024.0f - ); - - - // for(int i = 0; i < ggml_nelements(wino_res); i++) { - // for(int i = 0; i < 26*38; i++) { - for(int i = 0; i < conv2d_data.size(); i++) { - // float diff = fabs(conv2d_data[i] - wino_data[i]); - float diff = fabs(im2col_data[i] - conv2d_data[i]); - // if(diff > 0.5) { - printf("(%7.3f, %7.3f, %.2f, %d) \n", - im2col_data[i], conv2d_data[i], - diff, i); - // break; - // } - } - - ggml_free(model.ctx); - ggml_backend_buffer_free(model.buffer); - ggml_backend_free(model.backend); - ggml_gallocr_free(allocr); - - } - - // printf("\nPerforming test:\n"); - return 0; -} diff --git a/tests/test-transpose.cpp b/tests/test-transpose.cpp deleted file mode 100644 index 73263f3438..0000000000 --- a/tests/test-transpose.cpp +++ /dev/null @@ -1,375 +0,0 @@ -#include "ggml.h" -#include "ggml-alloc.h" -#include "ggml-cpu.h" -#include "ggml-backend.h" - -#ifdef GGML_USE_CUDA -#include "ggml-cuda.h" -//#include -#endif - -#ifdef GGML_USE_METAL -#include "ggml-metal.h" -#endif - -#include -#include -#include -#include -#include -#include -#include -#include - -static void ggml_log_callback_default(ggml_log_level level, const char * text, void * user_data) { - (void) level; - (void) user_data; - fputs(text, stderr); - fflush(stderr); -} - -struct test_model { - struct ggml_tensor * a; - struct ggml_tensor * b; - ggml_backend_t backend = NULL; - ggml_backend_buffer_t buffer; - struct ggml_context * ctx; -}; - - - -void load_model(test_model & model, int ic, int oc, int iw, int ih, int kw = 3, int kh = 3, bool use_gpu = false ) { - // create data - int KW = kw, KH = kh, IC = ic, OC = oc; - int IW = iw, IH = ih, N = 1; - srand(time(NULL)); - - // printf(" input: IC = %d, OC = %d, IW = %d, IH = %d \n ", IC, OC, IW, IH); - - // Initialize adata - std::vector adata(KW * KH * IC * OC); - for (int 
i = 0; i < KW * KH * IC * OC; i++) { - // adata[i] = 2.f; - adata[i] = (float)i; - // adata[i] = (rand() % 255) / 255.0; - // float r = -1.f + static_cast (rand()) /( static_cast (RAND_MAX/(1.f-(-1.f)))); - // adata[i] = r; - } - - // Convert adata to fp16 format - std::vector hadata(KW * KH * IC * OC); - ggml_fp32_to_fp16_row(adata.data(), hadata.data(), KW * KH * IC * OC); - - // Initialize bdata - std::vector bdata(IW * IH * IC * N); - for (int i = 0; i < IW * IH * IC * N; i++) { - // bdata[i] = (float)(i%IW)/10.f; - // bdata[i] = 1.5f; - bdata[i] = (float)(i+1); - // bdata[i] = (rand() % 255) / 255.0; - // float r = -1.f + static_cast (rand()) /( static_cast (RAND_MAX/(1.f-(-1.f)))); - // bdata[i] = r; - } - - // for(int i = 0; i < IH; i++) { - // // float diff = fabs(conv2d_data[i] - wino_data[i]); - // for(int j = 0; j < IW; j++) { - // printf("%.0f, ", bdata[i*IW+j]); - // } - // printf("\n"); - // } - for(int i = 0; i < KH; i++) { - // float diff = fabs(conv2d_data[i] - wino_data[i]); - for(int j = 0; j < KW; j++) { - printf("%.0f, ", adata[i*KW+j]); - } - printf("\n"); - } - printf(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n"); - - size_t buffer_size = 0; - { - // buffer_size += KW * KH * IC * OC * ggml_type_size(GGML_TYPE_F32); // tensor a - buffer_size += KW * KH * IC * OC * ggml_type_size(GGML_TYPE_F16); // tensor a - buffer_size += IW * IH * IC * N * ggml_type_size(GGML_TYPE_F32); // tensor b - buffer_size += 1024; // overhead - } - - // printf("%s: ggml tensor size = %d bytes\n", __func__, (int) sizeof(ggml_tensor)); - // printf("%s: backend buffer size = %0.2f MB\n", __func__, (buffer_size/ 1024.f/ 1024.f)); - - int num_tensors = 2; - struct ggml_init_params params { - /*.mem_size =*/ ggml_tensor_overhead() * num_tensors, - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ true, - }; - - // initialize the backend -#ifdef GGML_USE_CUDA - if (use_gpu) { - // fprintf(stderr, "%s: using CUDA backend\n", __func__); - model.backend = ggml_backend_cuda_init(0); - if (!model.backend) { - fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); - } - } -#endif - -#ifdef GGML_USE_METAL - if (use_gpu) { - fprintf(stderr, "%s: using Metal backend\n", __func__); - ggml_backend_metal_log_set_callback(ggml_log_callback_default, nullptr); - model.backend = ggml_backend_metal_init(); - if (!model.backend) { - fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__); - } - } -#endif - - if(!model.backend) { - // fallback to CPU backend - model.backend = ggml_backend_cpu_init(); - } - - model.buffer = ggml_backend_alloc_buffer(model.backend, buffer_size); - - // create context - model.ctx = ggml_init(params); - - // create tensors - model.a = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F16, KW, KH, IC, OC); - // model.a = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F32, KW, KH, IC, OC); - model.b = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F32, IW, IH, IC, N); - - int64_t *ne = model.a->ne; - printf("before trans: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); - - // create a allocator - struct ggml_tallocr alloc = ggml_tallocr_new(model.buffer); - - // alloc memory - ggml_tallocr_alloc(&alloc, model.a); - - // load data to buffer - if(ggml_backend_is_cpu(model.backend)) { - memcpy(model.a->data, hadata.data(), ggml_nbytes(model.a)); - // memcpy(model.a->data, adata.data(), ggml_nbytes(model.a)); - } else { - ggml_backend_tensor_set(model.a, hadata.data(), 0, ggml_nbytes(model.a)); - // ggml_backend_tensor_set(model.a, adata.data(), 0, ggml_nbytes(model.a)); - 
} - - // alloc memory - ggml_tallocr_alloc(&alloc, model.b); - - if(ggml_backend_is_cpu(model.backend) -#ifdef GGML_USE_METAL - || ggml_backend_is_metal(model.backend) -#endif - ) { - memcpy(model.b->data, bdata.data(), ggml_nbytes(model.b)); - } else { - ggml_backend_tensor_set(model.b, bdata.data(), 0, ggml_nbytes(model.b)); - } -} - -typedef struct ggml_cgraph* (*build_graph_t)(const test_model& model); - -struct ggml_cgraph * build_graph_0(const test_model& model) { - static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(); - static std::vector buf(buf_size); - - struct ggml_init_params params0 = { - /*.mem_size =*/ buf_size, - /*.mem_buffer =*/ buf.data(), - /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph() - }; - - // create a temporally context to build the graph - struct ggml_context * ctx0 = ggml_init(params0); - - struct ggml_cgraph * gf = ggml_new_graph(ctx0); - - int s0 = 1; - int s1 = 1; - int p0 = 1; - int p1 = 1; - int d0 = 1; - int d1 = 1; - - - - // recalculate for avoid fragmentation - // struct ggml_tensor* conv2d_res = ggml_cont(ctx0, ggml_transpose(ctx0, model.b)); - struct ggml_tensor* conv2d_res = ggml_cont(ctx0, ggml_transpose(ctx0, model.a)); - ggml_set_name(conv2d_res, "transpose_res"); - ggml_build_forward_expand(gf, conv2d_res); - int64_t *ne = conv2d_res->ne; - printf("conv2d: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); - - - // struct ggml_tensor* wino_res = ggml_conv_2d_3x3(ctx0, model.a, model.b); - // ggml_set_name(wino_res, "wino_res"); - // ggml_build_forward_expand(gf, wino_res); - // ne = wino_res->ne; - // printf("wino: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); - ggml_free(ctx0); - return gf; -} - - - -std::vector compute_graph(const test_model & model, ggml_gallocr_t allocr, - build_graph_t build_graph, int iters, double *t) { - struct ggml_cgraph * gf = build_graph(model); - - - // allocate tensors - ggml_gallocr_alloc_graph(allocr, gf); - int n_threads = 1; - - if (ggml_backend_is_cpu(model.backend)) { - ggml_backend_cpu_set_n_threads(model.backend, n_threads); - } - -#ifdef GGML_USE_METAL - if (ggml_backend_is_metal(model.backend)) { - ggml_backend_metal_set_n_cb(model.backend, n_threads); - } -#endif - - ggml_backend_synchronize(model.backend); - - ggml_backend_graph_compute(model.backend, gf); - - ggml_backend_synchronize(model.backend); - - int64_t start_time = ggml_time_us(); - - for(int iter=0; iter data(ggml_nelements(res)); - std::vector fdata(ggml_nelements(res)); - std::vector data(ggml_nelements(res)); - ggml_backend_tensor_get(res, fdata.data(), 0, ggml_nbytes(res)); - ggml_fp16_to_fp32_row(fdata.data(), data.data(), ggml_nelements(res)); - *t = time_us/1000; - return data; - -} - - -int main(void) -{ - ggml_time_init(); - std::vector> configs = { - // std::make_tuple(64,64,48,64,3,3), - // std::make_tuple(320,320,104,152,3,3), - // std::make_tuple(640,640,52,76,3,3), - // std::make_tuple(640,640,104,152,3,3), - // std::make_tuple(960,320,104,152,3,3), - // std::make_tuple(1,128,38,49,3,3), - std::make_tuple(1,1,38,49,38,49), - // std::make_tuple(1280,1280,26,38,1,1), - // std::make_tuple(256,128,768,1024,3,3), - // std::make_tuple(256,128,768,1024,1,1), - // std::make_tuple(1280,640,52,76,3,3), - // std::make_tuple(1920,1280,26,38,3,3), - // std::make_tuple(2560,1280,26,38,3,3), - // std::make_tuple(512,512,104,152,3,3), - // std::make_tuple(512,512,208,304,3,3), - // std::make_tuple(512,256,416,608,3,3), - // 
std::make_tuple(256,128,832,1216,3,3), - // std::make_tuple(256,256,832,1216,3,3), - // std::make_tuple(320,256,1024,1920) - }; - - int k = 0; - - for (auto c : configs){ - test_model model; - load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c), - std::get<3>(c), std::get<4>(c), std::get<5>(c), true); - - ggml_gallocr_t allocr = NULL; - allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend)); - - //create the worst case graph for memory usage estimation - struct ggml_cgraph * gf = build_graph_0(model); - - // compute the required memory - ggml_gallocr_reserve(allocr, gf); - size_t mem_size0 = ggml_gallocr_get_buffer_size(allocr, 0); - // fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f); - - - struct ggml_cgraph * gf_res_0 = NULL; - int iterations = 0; - - double run_time0; - std::vector im2col_data = compute_graph(model, allocr, build_graph_0, iterations, &run_time0); - - - - - - - - //create the worst case graph for memory usage estimation - - - - - - - - // for(int i = 0; i < ggml_nelements(wino_res); i++) { - // for(int i = 0; i < 26*38; i++) { - // for(int i = 0; i < std::get<2>(c); i++) { - // // float diff = fabs(conv2d_data[i] - wino_data[i]); - // for(int j = 0; j < std::get<3>(c); j++) { - // printf("%4.1f, ", im2col_data[i*std::get<3>(c)+j]); - // } - // printf("\n"); - // } - for(int i = 0; i < std::get<4>(c); i++) { - // float diff = fabs(conv2d_data[i] - wino_data[i]); - for(int j = 0; j < std::get<5>(c); j++) { - printf("%4.1f, ", im2col_data[i*std::get<5>(c)+j]); - } - printf("\n"); - } - - ggml_free(model.ctx); - ggml_backend_buffer_free(model.buffer); - ggml_backend_free(model.backend); - ggml_gallocr_free(allocr); - - } - - // printf("\nPerforming test:\n"); - return 0; -} From 70132278cb50e95955126459364b042c81757d4f Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 29 Oct 2025 21:57:12 -0400 Subject: [PATCH 046/122] more clean up --- ggml/src/ggml-cuda/cpy.cu | 6 +++--- ggml/src/ggml.c | 2 +- tests/test-backend-ops.cpp | 36 +++++------------------------------- 3 files changed, 9 insertions(+), 35 deletions(-) diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index c0a568f4ab..8567c3d5a1 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -138,7 +138,7 @@ void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_des #endif } -template +template static void ggml_cpy_flt_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, @@ -332,7 +332,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) { ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == 
GGML_TYPE_F16) { @@ -363,7 +363,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) { ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index a792d6b888..50dc1aa24f 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -4570,7 +4570,7 @@ struct ggml_tensor * ggml_conv_2d_direct( return result; } -// ggml_conv_3d +// ggml_conv_3d_direct struct ggml_tensor * ggml_conv_3d_direct( struct ggml_context * ctx, diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index a7aba2b447..177288c811 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2414,7 +2414,6 @@ struct test_cpy : public test_case { const std::array permute_dst; bool _src_use_permute; bool _dst_use_permute; - bool is_transpose; std::string vars() override { return VARS_TO_STR5(type_src, type_dst, ne, permute_src, permute_dst); @@ -2431,12 +2430,10 @@ struct test_cpy : public test_case { test_cpy(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32, std::array ne = {10, 10, 10, 1}, std::array permute_src = {0, 0, 0, 0}, - std::array permute_dst = {0, 0, 0, 0}, - bool transpose = false) + std::array permute_dst = {0, 0, 0, 0}) : type_src(type_src), type_dst(type_dst), ne(ne), permute_src(permute_src), permute_dst(permute_dst), _src_use_permute(permute_src[0] + permute_src[1] + permute_src[2] + permute_src[3] > 0), - _dst_use_permute(permute_dst[0] + permute_dst[1] + permute_dst[2] + permute_dst[3] > 0), - is_transpose(transpose) {} + _dst_use_permute(permute_dst[0] + permute_dst[1] + permute_dst[2] + permute_dst[3] > 0) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data()); @@ -2457,8 +2454,6 @@ struct test_cpy : public test_case { } ggml_tensor * out = ggml_cpy(ctx, src, dst); - if(is_transpose) - src->op_params[10] = 999; ggml_set_name(out, "out"); return out; @@ -6024,7 +6019,6 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_I32, {256, 2, 3, 4}, {1, 0, 2, 3})); test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_F32, {256, 2, 3, 4})); test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_F32, {256, 2, 3, 4}, {1, 0, 2, 3})); - test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {48, 48, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}, true)); test_cases.emplace_back(new test_cont()); test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 1, 1 ,1})); @@ -6685,32 +6679,12 @@ static std::vector> 
make_test_cases_perf() { GGML_TYPE_F32, 1, 1, p0, p1, 1, 1, false)); } - // for (auto act_case : cases_sd) { - // GGML_ASSERT(act_case[idx_sd["kw"]] == 3 || act_case[idx_sd["kw"]] == 1); - // GGML_ASSERT(act_case[idx_sd["kh"]] == 3 || act_case[idx_sd["kh"]] == 1); - - // uint32_t p0 = act_case[idx_sd["kw"]] == 3 ? 1 : 0; - // uint32_t p1 = act_case[idx_sd["kh"]] == 3 ? 1 : 0; - - // test_cases.emplace_back(new test_conv_2d_implicit( - // { act_case[idx_sd["iw"]], act_case[idx_sd["ih"]], act_case[idx_sd["Cin"]], act_case[idx_sd["B"]] }, - // { act_case[idx_sd["kw"]], act_case[idx_sd["kh"]], act_case[idx_sd["Cin"]], act_case[idx_sd["Cout"]] }, - // GGML_TYPE_F16, 1, 1, p0, p1, 1, 1, false)); - // } - test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1})); test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1})); - // test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F16, {512, 3072, 1, 1})); - // test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3})); - // test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3})); - test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}, true)); - test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}, false)); - test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}, true)); - test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}, false)); - test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}, true)); - test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}, false)); - + test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F16, {512, 3072, 1, 1})); + test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3})); + test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3})); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 256, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); From 1f3d5eb8e9e0ac6009c951da7de1649e23d8b70b Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 29 Oct 2025 22:47:03 -0400 Subject: [PATCH 047/122] prevent CI compile failure --- ggml/src/ggml-cuda/conv2d-implicit.cu | 30 +++++++++++++-------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index fcb053c61d..3bb0df8bdc 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -1007,12 +1007,12 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * const int cc = ggml_cuda_info().devices[ctx.device].cc; const int32_t * p = (const int32_t *) dst->op_params; - const int ST_X = p[0]; // stride_x - const int ST_Y = p[1]; // stride_y - const int PD_X = p[2]; // padding_x - const int PD_Y = p[3]; // padding_y - const int DL_X = p[4]; // dilation_x - const int DL_Y = p[5]; // dilation_y + const uint ST_X = p[0]; // stride_x + const uint ST_Y 
= p[1]; // stride_y + const uint PD_X = p[2]; // padding_x + const uint PD_Y = p[3]; // padding_y + const uint DL_X = p[4]; // dilation_x + const uint DL_Y = p[5]; // dilation_y // const int LT = p[6]; // layout // GGML_ASSERT(LT == 0 || LT == 1); @@ -1022,16 +1022,16 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * // No cwhn GGML_ASSERT(p[6] == false); - const int IW = input->ne[0]; // input_w - const int IH = input->ne[1]; // input_h - const int OW = dst->ne[0]; // output_w - const int OH = dst->ne[1]; // output_h - const int KW = kernel->ne[0]; // kernel_w - const int KH = kernel->ne[1]; // kernel_h - const int IC = input->ne[2]; // input_channels + const uint IW = input->ne[0]; // input_w + const uint IH = input->ne[1]; // input_h + const uint OW = dst->ne[0]; // output_w + const uint OH = dst->ne[1]; // output_h + const uint KW = kernel->ne[0]; // kernel_w + const uint KH = kernel->ne[1]; // kernel_h + const uint IC = input->ne[2]; // input_channels - const int OC = kernel->ne[3]; // ouptut_chanles - const int B = input->ne[3]; // n_batches + const uint OC = kernel->ne[3]; // ouptut_chanles + const uint B = input->ne[3]; // n_batches const int64_t total = B * OC * OH * OW; From c141ce3533bc8f112ac7e1492fc75a3ab12b51a3 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 29 Oct 2025 22:56:27 -0400 Subject: [PATCH 048/122] make CI happy --- ggml/src/ggml-cuda/conv2d-implicit.cu | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 3bb0df8bdc..521329d085 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -13,18 +13,18 @@ constexpr uint WARPSIZE = 32; //currently not use; in future for split-k kernels -static __global__ void reduce_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols, const int nrows) { - const int row = blockIdx.x; - const int col = threadIdx.x; +// static __global__ void reduce_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols, const int nrows) { +// const int row = blockIdx.x; +// const int col = threadIdx.x; - float sum = 0.0f; - if (row * blockDim.x + col < ncols) { - for (int i = 0; i < nrows; ++i){ - sum += x[i * ncols + row * blockDim.x + col]; - } - dst[row * blockDim.x + col] = sum; - } -} +// float sum = 0.0f; +// if (row * blockDim.x + col < ncols) { +// for (int i = 0; i < nrows; ++i){ +// sum += x[i * ncols + row * blockDim.x + col]; +// } +// dst[row * blockDim.x + col] = sum; +// } +// } template static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, const int ne00, const int ne01){ @@ -1033,8 +1033,6 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * const uint OC = kernel->ne[3]; // ouptut_chanles const uint B = input->ne[3]; // n_batches - const int64_t total = B * OC * OH * OW; - param_t params = { B, IC, IH, IW, OC, KH, KW, ST_Y, ST_X, PD_Y, PD_X, DL_Y, DL_X, OH, OW }; params.SC_fastdiv = init_fastdiv_values(KW*IC); params.OW_fastdiv = init_fastdiv_values(OW); From 2b5351a898718314ca841df3e6cf995c03e65309 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 29 Oct 2025 23:17:36 -0400 Subject: [PATCH 049/122] make CI happy --- ggml/src/ggml-cuda/conv2d-implicit.cu | 16 ++++++++-------- ggml/src/ggml-cuda/conv2d-implicit.cuh | 1 - 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu 
b/ggml/src/ggml-cuda/conv2d-implicit.cu index 521329d085..681ae45ee1 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -990,6 +990,7 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa static void conv2d_implicit_cuda_f32(ggml_backend_cuda_context & ctx, const float * X_D, const float * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) { conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); + GGML_UNUSED(ctx); } void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -1033,14 +1034,13 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * const uint OC = kernel->ne[3]; // ouptut_chanles const uint B = input->ne[3]; // n_batches - param_t params = { B, IC, IH, IW, OC, KH, KW, ST_Y, ST_X, PD_Y, PD_X, DL_Y, DL_X, OH, OW }; - params.SC_fastdiv = init_fastdiv_values(KW*IC); - params.OW_fastdiv = init_fastdiv_values(OW); - params.OHOW_fastdiv = init_fastdiv_values(OW*OH); - params.C_fastdiv = init_fastdiv_values(IC); - params.RS_fastdiv = init_fastdiv_values(KW*KH); - params.S_fastdiv = init_fastdiv_values(KW); - // params.layout = LT; + param_t params = { B, IC, IH, IW, OC, KH, KW, ST_Y, ST_X, PD_Y, PD_X, DL_Y, DL_X, OH, OW, + init_fastdiv_values(KW*IC), + init_fastdiv_values(OW), + init_fastdiv_values(IC), + init_fastdiv_values(KW*KH), + init_fastdiv_values(KW), + init_fastdiv_values(OW*OH)}; if (kernel->type == GGML_TYPE_F16) { conv2d_implicit_cuda_f16(ctx, X_D, (half *) K_D, Y_D, cc, params, st); diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index 8ed0109390..a3bfea687a 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -17,7 +17,6 @@ typedef struct{ unsigned int d_w; //dilation width unsigned int Oh; //output height unsigned int Ow; //output width - unsigned int layout; uint3 SC_fastdiv; uint3 OW_fastdiv; uint3 C_fastdiv; From c1f67c19e04bdc4659d20c61ec01199190b95e2b Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 29 Oct 2025 23:23:21 -0400 Subject: [PATCH 050/122] make CI happy --- ggml/src/ggml-cuda/conv2d-implicit.cu | 21 +++++++++++---------- ggml/src/ggml-cuda/conv2d-implicit.cuh | 10 +++++----- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 681ae45ee1..1b00a3dcea 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -85,8 +85,8 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, // Warp tile const uint lane_id = tx % WARPSIZE; const uint warp_id = tx / WARPSIZE; - const int mma_tid_x = warp_id / (BN / WN); - const int mma_tid_y = warp_id % (BN / WN); + const int mma_tid_x = warp_id / (BN / WN); + const int mma_tid_y = warp_id % (BN / WN); // size of the warp subtile constexpr uint WMITER = (WM * WN) / (WARPSIZE * TM * TN * WNITER); @@ -449,7 +449,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, const int n = (ksplit > 0) ? gemm_i / PQ : z; const int col = (ksplit > 0) ? gemm_i % PQ : gemm_i; if (n < param.n && row < param.k && col < param.Oh * param.Ow){ - const uint outOffset = ksplit > 0 ? + const uint outOffset = ksplit > 0 ? 
z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col : z * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col; @@ -626,7 +626,7 @@ __device__ __forceinline__ void ldmatrix_b( static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4"); static_assert(mma_tiles_per_warp_n == 8, "mma_tiles_per_warp_n must be 8"); - + uint32_t (®_) [4][8] = reinterpret_cast(reg); unsigned int logical_offset = (threadIdx.x % 32) * smem_stride; unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b10000000) >> 4); @@ -739,11 +739,11 @@ constexpr unsigned int MMA_N = 8; constexpr int BUFFER_SIZE = BM * BK + BK * BN; // declare register storage - // ptx instructions expect uint32_t registers, where each uint32_t is 2 halfs packed together + // ptx instructions expect uint32_t registers, where each uint32_t is 2 halfs packed together uint32_t acc_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][2]; uint32_t A_register[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2]; uint32_t B_register[mma_tiles_per_warp_k][mma_tiles_per_warp_n]; - + // convenience cast to half for register storage half (&acc_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_n][4] = reinterpret_cast(acc_register); half (&A_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] = reinterpret_cast(A_register); @@ -827,7 +827,7 @@ constexpr unsigned int MMA_N = 8; // reuse smem half *smemoutput = shmem; - const uint lane_id = threadIdx.x % WARPSIZE; + const uint lane_id = threadIdx.x % WARPSIZE; const uint mma_row = lane_id / 4; const uint mma_col = lane_id % 4; const uint output_lds_addr = warp_m * WM * BN/2 + lane_id * BN/2 + warp_n * WN/2; @@ -845,7 +845,7 @@ constexpr unsigned int MMA_N = 8; for (unsigned int mma_n = i * mma_tiles_per_warp_n/2; mma_n < (i+1)*mma_tiles_per_warp_n/2; mma_n++) { uint32_t (®_)[2] = reinterpret_cast(acc_register_[mma_m][mma_n]); - uint idx = output_sts_addr + + uint idx = output_sts_addr + mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N; idx = idx ^ ((idx & 0b1110000000) >> 4); uint32_t* dst_ptr = reinterpret_cast(&smemoutput[idx]); @@ -902,7 +902,7 @@ constexpr static int conv_shapes[][NUM_VARIANTS] = { }; template -static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const param_t P, cudaStream_t st) { +static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const param_t P, cudaStream_t st) { const uint BM = conv_shapes[0][CONV_SHAPE]; const uint BN = conv_shapes[1][CONV_SHAPE]; @@ -920,7 +920,7 @@ static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, int threadz = 1; // threadz number per block dim3 thblock(NUM_THREADS, thready, threadz); dim3 grid(blockx, blocky, blockz); - + conv2d_implicit_kernel<<>>(X_D, K_D, Y_D, P); } @@ -991,6 +991,7 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa static void conv2d_implicit_cuda_f32(ggml_backend_cuda_context & ctx, const float * X_D, const float * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) { conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); GGML_UNUSED(ctx); + GGML_UNUSED(cc); } void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index a3bfea687a..347ca12b3e 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -137,7 +137,7 @@ __device__ 
__forceinline__ void tileMemcpySwizzleA( unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); - if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && + if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curR < param.r && curS < param.s && curC < param.c){ const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; dst_float4[dst_index] = reinterpret_cast(&src[inOffset + inOffsetTmp])[0]; @@ -199,7 +199,7 @@ __device__ __forceinline__ void tileMemcpyLoadA( const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset int curH = posh_ori + curR * param.d_h; // input h int curW = posw_ori + curS * param.d_w; // input w - if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && + if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curR < param.r && curS < param.s && curC < param.c){ const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; dst_reg[i] = reinterpret_cast(&src[inOffset + inOffsetTmp])[0]; @@ -215,7 +215,7 @@ __device__ __forceinline__ void tileMemcpyLoadA( GGML_UNUSED(inChannelOffset); GGML_UNUSED(param); NO_DEVICE_CODE; -#endif +#endif } @@ -299,7 +299,7 @@ __device__ __forceinline__ void tileMemcpySwizzleStore( // # of threads is multiple of # of columns in the tile constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); - + // flatten out 2d grid of threads into in order of increasing threadIdx.x const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; @@ -312,7 +312,7 @@ __device__ __forceinline__ void tileMemcpySwizzleStore( // compile time check that we provided the right amount of registers for storage static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); - + #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++) { From 417cfc3cc6ce0a72c3c139cf8fd68c9718966ca1 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Fri, 31 Oct 2025 19:57:28 -0400 Subject: [PATCH 051/122] added a test case to directly compare im2col and implicit gemm --- tests/CMakeLists.txt | 1 + tests/test-conv2d.cpp | 413 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 414 insertions(+) create mode 100644 tests/test-conv2d.cpp diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 9171957756..aaabfe91b7 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -198,6 +198,7 @@ if (NOT LLAMA_SANITIZE_ADDRESS) endif() llama_build_and_test(test-gguf.cpp) llama_build_and_test(test-backend-ops.cpp) +llama_build_and_test(test-conv2d.cpp) llama_build_and_test(test-model-load-cancel.cpp LABEL "model") llama_build_and_test(test-autorelease.cpp LABEL "model") diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp new file mode 100644 index 0000000000..c2cc1930cb --- /dev/null +++ b/tests/test-conv2d.cpp @@ -0,0 +1,413 @@ +#include "ggml.h" +#include "ggml-alloc.h" +#include "ggml-cpu.h" +#include "ggml-backend.h" + +#ifdef GGML_USE_CUDA +#include "ggml-cuda.h" +//#include +#endif + +#ifdef GGML_USE_METAL +#include "ggml-metal.h" +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +static void ggml_log_callback_default(ggml_log_level level, const char * text, void * user_data) { + (void) level; + (void) user_data; + fputs(text, stderr); + fflush(stderr); +} + 
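Both graphs this test builds further down (the im2col+GEMM path via ggml_conv_2d and the direct/implicit-GEMM path) are expected to compute the same thing: a 2D cross-correlation of the [IW, IH, IC, N] input with the [KW, KH, IC, OC] kernel under the given stride, padding and dilation. As a point of reference for what either backend should return, a minimal CPU sketch is shown below; it assumes the contiguous WHCN layouts used by the tensors in this file and plain float accumulation, and conv2d_reference is an illustrative helper, not part of the patch.

// Scalar reference for the convolution both graphs compute (illustrative only).
// src: [IW, IH, IC, N], ker: [KW, KH, IC, OC], dst: [OW, OH, OC, N], all contiguous.
static void conv2d_reference(const float * src, const float * ker, float * dst,
                             int IW, int IH, int IC, int N, int KW, int KH, int OC,
                             int s0, int s1, int p0, int p1, int d0, int d1) {
    const int OW = (IW + 2*p0 - d0*(KW - 1) - 1)/s0 + 1; // output width
    const int OH = (IH + 2*p1 - d1*(KH - 1) - 1)/s1 + 1; // output height
    for (int n = 0; n < N; ++n)
    for (int oc = 0; oc < OC; ++oc)
    for (int oh = 0; oh < OH; ++oh)
    for (int ow = 0; ow < OW; ++ow) {
        float acc = 0.0f;
        for (int ic = 0; ic < IC; ++ic)
        for (int kh = 0; kh < KH; ++kh)
        for (int kw = 0; kw < KW; ++kw) {
            const int ih = oh*s1 + kh*d1 - p1; // input row
            const int iw = ow*s0 + kw*d0 - p0; // input column
            if (ih < 0 || ih >= IH || iw < 0 || iw >= IW) continue; // zero padding
            acc += src[((n*IC + ic)*IH + ih)*IW + iw]
                 * ker[((oc*IC + ic)*KH + kh)*KW + kw];
        }
        dst[((n*OC + oc)*OH + oh)*OW + ow] = acc;
    }
}

The commented-out loop near the end of main() performs the analogous element-wise comparison between the two GPU results (im2col_data vs. conv2d_data).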
+struct test_model { + struct ggml_tensor * a; + struct ggml_tensor * b; + ggml_backend_t backend = NULL; + ggml_backend_buffer_t buffer; + struct ggml_context * ctx; +}; + + + +void load_model(test_model & model, int ic, int oc, int iw, int ih, int kw = 3, int kh = 3, bool use_gpu = false ) { + // create data + int KW = kw, KH = kh, IC = ic, OC = oc; + int IW = iw, IH = ih, N = 1; + srand(time(NULL)); + + // printf(" input: IC = %d, OC = %d, IW = %d, IH = %d \n ", IC, OC, IW, IH); + + // Initialize adata + std::vector adata(KW * KH * IC * OC); + for (int i = 0; i < KW * KH * IC * OC; i++) { + // adata[i] = 2.f; + // adata[i] = (float)(i%KW)-1.f; + // adata[i] = (rand() % 255) / 255.0; + float r = -1.f + static_cast (rand()) /( static_cast (RAND_MAX/(1.f-(-1.f)))); + adata[i] = r; + } + + // Convert adata to fp16 format + std::vector hadata(KW * KH * IC * OC); + ggml_fp32_to_fp16_row(adata.data(), hadata.data(), KW * KH * IC * OC); + + // Initialize bdata + std::vector bdata(IW * IH * IC * N); + for (int i = 0; i < IW * IH * IC * N; i++) { + // bdata[i] = (float)(i%IW)/10.f; + // bdata[i] = 1.5f; + // bdata[i] = (rand() % 255) / 255.0; + float r = -1.f + static_cast (rand()) /( static_cast (RAND_MAX/(1.f-(-1.f)))); + bdata[i] = r; + } + + size_t buffer_size = 0; + { + // buffer_size += KW * KH * IC * OC * ggml_type_size(GGML_TYPE_F32); // tensor a + buffer_size += KW * KH * IC * OC * ggml_type_size(GGML_TYPE_F16); // tensor a + buffer_size += IW * IH * IC * N * ggml_type_size(GGML_TYPE_F32); // tensor b + buffer_size += 1024; // overhead + } + + // printf("%s: ggml tensor size = %d bytes\n", __func__, (int) sizeof(ggml_tensor)); + // printf("%s: backend buffer size = %0.2f MB\n", __func__, (buffer_size/ 1024.f/ 1024.f)); + + int num_tensors = 2; + struct ggml_init_params params { + /*.mem_size =*/ ggml_tensor_overhead() * num_tensors, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + // initialize the backend +#ifdef GGML_USE_CUDA + if (use_gpu) { + // fprintf(stderr, "%s: using CUDA backend\n", __func__); + model.backend = ggml_backend_cuda_init(0); + if (!model.backend) { + fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); + } + } +#endif + +#ifdef GGML_USE_METAL + if (use_gpu) { + fprintf(stderr, "%s: using Metal backend\n", __func__); + ggml_backend_metal_log_set_callback(ggml_log_callback_default, nullptr); + model.backend = ggml_backend_metal_init(); + if (!model.backend) { + fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__); + } + } +#endif + + if(!model.backend) { + // fallback to CPU backend + model.backend = ggml_backend_cpu_init(); + } + + model.buffer = ggml_backend_alloc_buffer(model.backend, buffer_size); + + // create context + model.ctx = ggml_init(params); + + // create tensors + model.a = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F16, KW, KH, IC, OC); + // model.a = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F32, KW, KH, IC, OC); + model.b = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F32, IW, IH, IC, N); + + // create a allocator + struct ggml_tallocr alloc = ggml_tallocr_new(model.buffer); + + // alloc memory + ggml_tallocr_alloc(&alloc, model.a); + + // load data to buffer + if(ggml_backend_is_cpu(model.backend)) { + memcpy(model.a->data, hadata.data(), ggml_nbytes(model.a)); + // memcpy(model.a->data, adata.data(), ggml_nbytes(model.a)); + } else { + ggml_backend_tensor_set(model.a, hadata.data(), 0, ggml_nbytes(model.a)); + // ggml_backend_tensor_set(model.a, adata.data(), 0, ggml_nbytes(model.a)); + } + + // alloc memory + 
ggml_tallocr_alloc(&alloc, model.b); + + if(ggml_backend_is_cpu(model.backend) +#ifdef GGML_USE_METAL + || ggml_backend_is_metal(model.backend) +#endif + ) { + memcpy(model.b->data, bdata.data(), ggml_nbytes(model.b)); + } else { + ggml_backend_tensor_set(model.b, bdata.data(), 0, ggml_nbytes(model.b)); + } +} + +typedef struct ggml_cgraph* (*build_graph_t)(const test_model& model); + +struct ggml_cgraph * build_graph_0(const test_model& model) { + static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(); + static std::vector buf(buf_size); + + struct ggml_init_params params0 = { + /*.mem_size =*/ buf_size, + /*.mem_buffer =*/ buf.data(), + /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph() + }; + + // create a temporally context to build the graph + struct ggml_context * ctx0 = ggml_init(params0); + + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + int s0 = 1; + int s1 = 1; + int p0 = 1; + int p1 = 1; + int d0 = 1; + int d1 = 1; + + + + // recalculate for avoid fragmentation + struct ggml_tensor* conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); + ggml_set_name(conv2d_res, "conv2d_res"); + ggml_build_forward_expand(gf, conv2d_res); + // int64_t *ne = conv2d_res->ne; + // printf("conv2d: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); + + + // struct ggml_tensor* wino_res = ggml_conv_2d_3x3(ctx0, model.a, model.b); + // ggml_set_name(wino_res, "wino_res"); + // ggml_build_forward_expand(gf, wino_res); + // ne = wino_res->ne; + // printf("wino: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); + ggml_free(ctx0); + return gf; +} + +struct ggml_cgraph * build_graph_1(const test_model& model) { + static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(); + static std::vector buf(buf_size); + + struct ggml_init_params params0 = { + /*.mem_size =*/ buf_size, + /*.mem_buffer =*/ buf.data(), + /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph() + }; + + // create a temporally context to build the graph + struct ggml_context * ctx0 = ggml_init(params0); + + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + int s0 = 1; + int s1 = 1; + int p0 = 1; + int p1 = 1; + int d0 = 1; + int d1 = 1; + + + + // recalculate for avoid fragmentation + // struct ggml_tensor* conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); + // ggml_set_name(conv2d_res, "conv2d_res"); + // ggml_build_forward_expand(gf, conv2d_res); + // int64_t *ne = conv2d_res->ne; + // printf("conv2d: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); + + + // struct ggml_tensor* wino_res = ggml_conv_2d_implicitgemm(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); + struct ggml_tensor* wino_res = ggml_conv_2d_direct(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); + ggml_set_name(wino_res, "wino_res"); + ggml_build_forward_expand(gf, wino_res); + // ne = wino_res->ne; + // printf("wino: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); + ggml_free(ctx0); + return gf; +} + + + + +std::vector compute_graph(const test_model & model, ggml_gallocr_t allocr, + build_graph_t build_graph, int iters, double *t) { + struct ggml_cgraph * gf = build_graph(model); + + + // allocate tensors + ggml_gallocr_alloc_graph(allocr, gf); + int n_threads = 1; + + if (ggml_backend_is_cpu(model.backend)) { + ggml_backend_cpu_set_n_threads(model.backend, n_threads); + } + +#ifdef GGML_USE_METAL + if (ggml_backend_is_metal(model.backend)) { + 
ggml_backend_metal_set_n_cb(model.backend, n_threads); + } +#endif + + + + ggml_backend_graph_compute(model.backend, gf); + + ggml_backend_synchronize(model.backend); + + int64_t start_time = ggml_time_us(); + + for(int iter=0; iter data(ggml_nelements(res)); + ggml_backend_tensor_get(res, data.data(), 0, ggml_nbytes(res)); + + *t = time_us/1000; + return data; + +} + + +int main(void) +{ + ggml_time_init(); + std::vector> configs = { + // std::make_tuple(64,64,48,64,3,3), + // std::make_tuple(320,320,104,152,3,3), + // std::make_tuple(640,640,52,76,3,3), + // std::make_tuple(640,640,104,152,3,3), + // std::make_tuple(960,320,104,152,3,3), + // std::make_tuple(1280,1280,26,38,3,3), + // std::make_tuple(1280,1280,26,38,1,1), + // std::make_tuple(256,128,768,1024,3,3), + // std::make_tuple(128,3,768,1024,3,3), + // std::make_tuple(256,128,768,1024,1,1), + // std::make_tuple(512,256,384,512,1,1), + // std::make_tuple(1280,640,52,76,3,3), + // std::make_tuple(1920,1280,26,38,3,3), + // std::make_tuple(2560,1280,26,38,3,3), + std::make_tuple(320,1280,26,38,3,3), + // std::make_tuple(512,512,104,152,3,3), + // std::make_tuple(512,512,208,304,3,3), + // std::make_tuple(512,256,416,608,3,3), + // std::make_tuple(256,128,832,1216,3,3), + // std::make_tuple(256,256,832,1216,3,3), + // std::make_tuple(320,256,1024,1920) + }; + + int k = 0; + + for (auto c : configs){ + test_model model; + load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c), + std::get<3>(c), std::get<4>(c), std::get<5>(c), true); + + ggml_gallocr_t allocr = NULL; + allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend)); + + //create the worst case graph for memory usage estimation + struct ggml_cgraph * gf = build_graph_0(model); + + // compute the required memory + ggml_gallocr_reserve(allocr, gf); + size_t mem_size0 = ggml_gallocr_get_buffer_size(allocr, 0); + // fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f); + + + struct ggml_cgraph * gf_res_0 = NULL; + int iterations = 20; + + double run_time0; + std::vector im2col_data = compute_graph(model, allocr, build_graph_0, iterations, &run_time0); + + ggml_gallocr_free(allocr); + + allocr = NULL; + + allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend)); + + //create the worst case graph for memory usage estimation + gf = build_graph_1(model); + + // compute the required memory + ggml_gallocr_reserve(allocr, gf); + size_t mem_size1 = ggml_gallocr_get_buffer_size(allocr, 0); + // fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f); + + + struct ggml_cgraph * gf_res_1 = NULL; + + double run_time1; + // std::vector wino_data = compute_graph(model, allocr, build_graph_1, iterations, &run_time1); + std::vector conv2d_data = compute_graph(model, allocr, build_graph_1, iterations, &run_time1); + + if(k==0) { + k = 1; + fprintf(stderr, "| (IC, OC, IW, IH) | im2col+GEMM TIME | im2col+GEMM VRAM | implicit GEMM TIME | implicit GEMM VRAM \n"); + fprintf(stderr, "| --- | --- | --- | --- | --- \n"); + } + + fprintf(stderr, " | (%d, %d, %d, %d, %d, %d) | %.2f ms | %.2f MB | %.2f ms | %.2f MB\n", + std::get<0>(c), std::get<1>(c), std::get<2>(c), std::get<3>(c), std::get<4>(c), std::get<5>(c), + run_time0, mem_size0/1024.0f/1024.0f, + run_time1, mem_size1/1024.0f/1024.0f); + + + // for(int i = 0; i < ggml_nelements(wino_res); i++) { + // for(int i = 0; i < 26*38; i++) { + // for(int i = 0; i < conv2d_data.size(); i++) { + // // float diff = fabs(conv2d_data[i] - 
wino_data[i]); + // float diff = fabs(im2col_data[i] - wino_data[i]); + // float diff1 = fabs(im2col_data[i] - conv2d_data[i]); + // // if(diff > 0.5) { + // printf("(%7.3f, %7.3f, %7.3f, %.2f, %.2f, %d) \n", + // im2col_data[i], conv2d_data[i], + // wino_data[i], diff, diff1, i); + // // break; + // // } + // } + + ggml_free(model.ctx); + ggml_backend_buffer_free(model.buffer); + ggml_backend_free(model.backend); + ggml_gallocr_free(allocr); + + } + + // printf("\nPerforming test:\n"); + return 0; +} From f95664c76c478510cec67d463fe73266bf43fcf4 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sat, 1 Nov 2025 14:35:44 -0400 Subject: [PATCH 052/122] make tensor core path available for cc 7.5 and above --- ggml/src/ggml-cuda/conv2d-implicit.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 1b00a3dcea..6bc93b2a57 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -927,7 +927,7 @@ static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const float * X_D, const half * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) { - if (GGML_CUDA_CC_IS_NVIDIA(cc) && ampere_mma_available(cc) && P.c % 8 == 0 && (P.r > 1 || P.s > 1)) { + if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0 && (P.r > 1 || P.s > 1)) { int id = ggml_cuda_get_device(); From fa9e415c9be1f7d29d614b56fa8c435459cc4736 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Mon, 3 Nov 2025 15:48:57 -0500 Subject: [PATCH 053/122] minor update of test case --- tests/test-conv2d.cpp | 41 +++++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index c2cc1930cb..afca57459a 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -306,26 +306,27 @@ int main(void) { ggml_time_init(); std::vector> configs = { - // std::make_tuple(64,64,48,64,3,3), - // std::make_tuple(320,320,104,152,3,3), - // std::make_tuple(640,640,52,76,3,3), - // std::make_tuple(640,640,104,152,3,3), - // std::make_tuple(960,320,104,152,3,3), - // std::make_tuple(1280,1280,26,38,3,3), - // std::make_tuple(1280,1280,26,38,1,1), - // std::make_tuple(256,128,768,1024,3,3), - // std::make_tuple(128,3,768,1024,3,3), - // std::make_tuple(256,128,768,1024,1,1), - // std::make_tuple(512,256,384,512,1,1), - // std::make_tuple(1280,640,52,76,3,3), - // std::make_tuple(1920,1280,26,38,3,3), - // std::make_tuple(2560,1280,26,38,3,3), + std::make_tuple(64,64,48,64,3,3), + std::make_tuple(320,320,104,152,3,3), + std::make_tuple(640,640,52,76,3,3), + std::make_tuple(640,640,104,152,3,3), + std::make_tuple(960,320,104,152,3,3), + std::make_tuple(1280,1280,26,38,3,3), std::make_tuple(320,1280,26,38,3,3), - // std::make_tuple(512,512,104,152,3,3), - // std::make_tuple(512,512,208,304,3,3), - // std::make_tuple(512,256,416,608,3,3), - // std::make_tuple(256,128,832,1216,3,3), - // std::make_tuple(256,256,832,1216,3,3), + std::make_tuple(1280,1280,26,38,1,1), + std::make_tuple(256,128,768,1024,3,3), + std::make_tuple(128,3,768,1024,3,3), + std::make_tuple(256,128,768,1024,1,1), + std::make_tuple(512,256,384,512,1,1), + std::make_tuple(1280,640,52,76,3,3), + std::make_tuple(1920,1280,26,38,3,3), + std::make_tuple(2560,1280,26,38,3,3), + std::make_tuple(320,1280,26,38,3,3), + std::make_tuple(512,512,104,152,3,3), + 
std::make_tuple(512,512,208,304,3,3), + std::make_tuple(512,256,416,608,3,3), + std::make_tuple(256,128,832,1216,3,3), + std::make_tuple(256,256,832,1216,3,3), // std::make_tuple(320,256,1024,1920) }; @@ -377,7 +378,7 @@ int main(void) if(k==0) { k = 1; - fprintf(stderr, "| (IC, OC, IW, IH) | im2col+GEMM TIME | im2col+GEMM VRAM | implicit GEMM TIME | implicit GEMM VRAM \n"); + fprintf(stderr, "| (IC, OC, IW, IH, KW, KH) | im2col+GEMM TIME | im2col+GEMM VRAM | implicit GEMM TIME | implicit GEMM VRAM \n"); fprintf(stderr, "| --- | --- | --- | --- | --- \n"); } From 27881fbe7b8c358cb2736f507ae742f9c3732d4d Mon Sep 17 00:00:00 2001 From: bssrdf Date: Mon, 3 Nov 2025 19:43:55 -0500 Subject: [PATCH 054/122] fixes for CI --- tests/test-conv2d.cpp | 34 ++++++++++------------------------ 1 file changed, 10 insertions(+), 24 deletions(-) diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index afca57459a..d7560be8ff 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -36,7 +36,9 @@ struct test_model { struct ggml_context * ctx; }; - +void load_model(test_model &, int, int, int, int, int, int, bool); +struct ggml_cgraph * build_graph_0(const test_model&); +struct ggml_cgraph * build_graph_1(const test_model&); void load_model(test_model & model, int ic, int oc, int iw, int ih, int kw = 3, int kh = 3, bool use_gpu = false ) { // create data @@ -102,7 +104,6 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, int kw = 3, #ifdef GGML_USE_METAL if (use_gpu) { fprintf(stderr, "%s: using Metal backend\n", __func__); - ggml_backend_metal_log_set_callback(ggml_log_callback_default, nullptr); model.backend = ggml_backend_metal_init(); if (!model.backend) { fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__); @@ -178,8 +179,6 @@ struct ggml_cgraph * build_graph_0(const test_model& model) { int d0 = 1; int d1 = 1; - - // recalculate for avoid fragmentation struct ggml_tensor* conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); ggml_set_name(conv2d_res, "conv2d_res"); @@ -219,8 +218,6 @@ struct ggml_cgraph * build_graph_1(const test_model& model) { int d0 = 1; int d1 = 1; - - // recalculate for avoid fragmentation // struct ggml_tensor* conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); // ggml_set_name(conv2d_res, "conv2d_res"); @@ -239,7 +236,8 @@ struct ggml_cgraph * build_graph_1(const test_model& model) { return gf; } - +std::vector compute_graph(const test_model &, ggml_gallocr_t, + build_graph_t, int, double *); std::vector compute_graph(const test_model & model, ggml_gallocr_t allocr, @@ -255,14 +253,6 @@ std::vector compute_graph(const test_model & model, ggml_gallocr_t allocr ggml_backend_cpu_set_n_threads(model.backend, n_threads); } -#ifdef GGML_USE_METAL - if (ggml_backend_is_metal(model.backend)) { - ggml_backend_metal_set_n_cb(model.backend, n_threads); - } -#endif - - - ggml_backend_graph_compute(model.backend, gf); ggml_backend_synchronize(model.backend); @@ -274,13 +264,11 @@ std::vector compute_graph(const test_model & model, ggml_gallocr_t allocr ggml_backend_synchronize(model.backend); } - // ggml_backend_synchronize(model.backend); int64_t end_time = ggml_time_us(); double time_us = end_time - start_time; time_us = time_us/iters; - // printf(" Taking %f ms\n ", time_us/1000); - + //ggml_graph_print(gf); struct ggml_tensor *res = NULL; @@ -334,7 +322,7 @@ int main(void) for (auto c : configs){ test_model model; - load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c), + 
load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c), std::get<3>(c), std::get<4>(c), std::get<5>(c), true); ggml_gallocr_t allocr = NULL; @@ -349,7 +337,6 @@ int main(void) // fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f); - struct ggml_cgraph * gf_res_0 = NULL; int iterations = 20; double run_time0; @@ -368,15 +355,14 @@ int main(void) ggml_gallocr_reserve(allocr, gf); size_t mem_size1 = ggml_gallocr_get_buffer_size(allocr, 0); // fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f); - - struct ggml_cgraph * gf_res_1 = NULL; + double run_time1; // std::vector wino_data = compute_graph(model, allocr, build_graph_1, iterations, &run_time1); std::vector conv2d_data = compute_graph(model, allocr, build_graph_1, iterations, &run_time1); - if(k==0) { + if(k==0) { k = 1; fprintf(stderr, "| (IC, OC, IW, IH, KW, KH) | im2col+GEMM TIME | im2col+GEMM VRAM | implicit GEMM TIME | implicit GEMM VRAM \n"); fprintf(stderr, "| --- | --- | --- | --- | --- \n"); @@ -409,6 +395,6 @@ int main(void) } - // printf("\nPerforming test:\n"); + // printf("\nPerforming test:\n"); return 0; } From 8572313000c1d2d783ee3efe19adbaf293ab1dce Mon Sep 17 00:00:00 2001 From: bssrdf Date: Mon, 3 Nov 2025 19:45:22 -0500 Subject: [PATCH 055/122] remove trailing blank --- tests/test-conv2d.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index d7560be8ff..2dda7e735a 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -368,7 +368,7 @@ int main(void) fprintf(stderr, "| --- | --- | --- | --- | --- \n"); } - fprintf(stderr, " | (%d, %d, %d, %d, %d, %d) | %.2f ms | %.2f MB | %.2f ms | %.2f MB\n", + fprintf(stderr, " | (%d, %d, %d, %d, %d, %d) | %.2f ms | %.2f MB | %.2f ms | %.2f MB\n", std::get<0>(c), std::get<1>(c), std::get<2>(c), std::get<3>(c), std::get<4>(c), std::get<5>(c), run_time0, mem_size0/1024.0f/1024.0f, run_time1, mem_size1/1024.0f/1024.0f); From 00a49c2fc1701d91858f31003a878278992dfa03 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Mon, 3 Nov 2025 19:49:56 -0500 Subject: [PATCH 056/122] another CI fix --- tests/test-conv2d.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index 2dda7e735a..0745121ecf 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -99,6 +99,8 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, int kw = 3, fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); } } +#else + GGML_UNUSED(use_gpu); #endif #ifdef GGML_USE_METAL @@ -109,6 +111,8 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, int kw = 3, fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__); } } +#else + GGML_UNUSED(use_gpu); #endif if(!model.backend) { From 275c08d25dfc0d730b4bfd48c3570496ff004990 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Tue, 4 Nov 2025 15:16:31 -0500 Subject: [PATCH 057/122] add more sd like test cases --- tests/test-conv2d.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index 0745121ecf..48dcaa47d8 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -304,6 +304,11 @@ int main(void) std::make_tuple(640,640,104,152,3,3), std::make_tuple(960,320,104,152,3,3), std::make_tuple(1280,1280,26,38,3,3), + std::make_tuple(4,320,96,128,3,3), + std::make_tuple(320,4,96,128,3,3), + std::make_tuple(4,320,64,96,3,3), + std::make_tuple(320,4,64,96,3,3), + 
std::make_tuple(640,640,96,128,3,3), std::make_tuple(320,1280,26,38,3,3), std::make_tuple(1280,1280,26,38,1,1), std::make_tuple(256,128,768,1024,3,3), From 6f44f471133564003649fafc84b7264b28d39317 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 5 Nov 2025 13:04:37 -0500 Subject: [PATCH 058/122] added split-k mode for skinny mnk shapes --- ggml/src/ggml-cuda/conv2d-implicit.cu | 108 +++++++++++++++++-------- ggml/src/ggml-cuda/conv2d-implicit.cuh | 41 ++++++---- tests/test-conv2d.cpp | 12 ++- 3 files changed, 103 insertions(+), 58 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 6bc93b2a57..216f895922 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -13,18 +13,19 @@ constexpr uint WARPSIZE = 32; //currently not use; in future for split-k kernels -// static __global__ void reduce_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols, const int nrows) { -// const int row = blockIdx.x; -// const int col = threadIdx.x; +template +static __global__ void reduce_f32(const src_T * __restrict__ x, dst_T * __restrict__ dst, const int ncols, const int nrows) { + const int row = blockIdx.x; + const int col = threadIdx.x; -// float sum = 0.0f; -// if (row * blockDim.x + col < ncols) { -// for (int i = 0; i < nrows; ++i){ -// sum += x[i * ncols + row * blockDim.x + col]; -// } -// dst[row * blockDim.x + col] = sum; -// } -// } + float sum = 0.0f; + if (row * blockDim.x + col < ncols) { + for (int i = 0; i < nrows; ++i){ + sum += ggml_cuda_cast(x[i * ncols + row * blockDim.x + col]); + } + dst[row * blockDim.x + col] = ggml_cuda_cast(sum); + } +} template static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, const int ne00, const int ne01){ @@ -705,26 +706,32 @@ __device__ __forceinline__ void ldmatrix_b( } template + const int WK, const int ksplit, const int NUM_THREADS> static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, const half * __restrict__ kernel, half * __restrict__ output, const param_t param) { #if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING -constexpr unsigned int MMA_M = 16; -constexpr unsigned int MMA_N = 8; - + constexpr unsigned int MMA_M = 16; + constexpr unsigned int MMA_N = 8; const unsigned int K = param.c * param.r * param.s; const uint inChannelOffset = param.c * param.w; - const uint weightKOffset = param.c * param.r * param.s; + const uint weightKOffset = K; // loop bounds, constexpr where possible allows for loop unrolling constexpr unsigned int mma_tiles_per_warp_k = 4; constexpr unsigned int mma_tiles_per_warp_m = WM / MMA_M; constexpr unsigned int mma_tiles_per_warp_n = WN / MMA_N; - const unsigned int num_block_tiles_k = (K + (BK-1)) / BK; + const unsigned int z = blockIdx.z; + + const unsigned int ks = (ksplit > 0) ? (weightKOffset + ksplit - 1) / ksplit : weightKOffset; + const unsigned int start_k = (ksplit > 0) ? 
z * ks : 0; + const unsigned int end_k = min(start_k + ks, weightKOffset); + const unsigned int num_block_tiles_k = (ks + (BK-1)) / BK; + + // calculate block/warp indices const unsigned int block_m = blockIdx.y; @@ -770,8 +777,8 @@ constexpr unsigned int MMA_N = 8; const half* A_block_gmem = input; const half* B_block_gmem = kernel + block_n * BN * weightKOffset; - tileMemcpySwizzleA(A_block_gmem, A_block_smem, inChannelOffset, param); - tileMemcpySwizzleB(B_block_gmem, B_block_smem, weightKOffset, param); + tileMemcpySwizzleA(A_block_gmem, A_block_smem, start_k, end_k, inChannelOffset, param); + tileMemcpySwizzleB(B_block_gmem, B_block_smem, start_k, end_k, weightKOffset, param); int offset_direction = 1; @@ -781,8 +788,8 @@ constexpr unsigned int MMA_N = 8; if (block_k != num_block_tiles_k){ const half* A_block_gmem = input; const half* B_block_gmem = kernel + (block_n * BN * weightKOffset); - tileMemcpyLoadA(A_block_gmem, A_gmem_cache_reg, block_k * BK, inChannelOffset, param); - tileMemcpyLoadB(B_block_gmem, B_gmem_cache_reg, block_k * BK, weightKOffset, param); + tileMemcpyLoadA(A_block_gmem, A_gmem_cache_reg, block_k * BK, start_k, end_k, inChannelOffset, param); + tileMemcpyLoadB(B_block_gmem, B_gmem_cache_reg, block_k * BK, start_k, end_k, weightKOffset, param); } half* A_warp_tile = A_block_smem + (warp_m * WM * BK); half* B_warp_tile = B_block_smem + (warp_n * WN * BK); @@ -813,6 +820,8 @@ constexpr unsigned int MMA_N = 8; } + + if (block_k != num_block_tiles_k) { // switch smem buffers each iteration @@ -863,11 +872,18 @@ constexpr unsigned int MMA_N = 8; const uint gemm_i = n_idx + j*32; const int n = fastdiv(gemm_i, param.OHOW_fastdiv); const int col = fastmodulo(gemm_i, param.OHOW_fastdiv); - if(n < param.n && row < param.k && col < param.Oh * param.Ow){ - const uint outOffset = n * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col; + if (n < param.n && row < param.k && col < param.Oh * param.Ow) { uint idx = output_lds_addr + subk + j*32*BN/2; idx = idx ^ ((idx & 0b1110000000) >> 4); - output[outOffset] = smemoutput[idx]; + if constexpr (ksplit > 0) { + const uint outOffset = z * param.n * param.k * param.Oh * param.Ow + + n * param.k * param.Oh * param.Ow + + row * param.Oh * param.Ow + col; + output[outOffset] = smemoutput[idx]; + } else { + const uint outOffset = n * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col; + output[outOffset] = smemoutput[idx]; + } } } } @@ -952,7 +968,6 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa const half *X_H = input_f16.get(); const half *K_H = kernel_f16.get(); - ggml_cuda_pool_alloc Y_H(ctx.pool(id), P.k * P.Oh * P.Ow * P.n); constexpr unsigned int BM_dim = 256; constexpr unsigned int BN_dim = 256; @@ -972,16 +987,41 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa constexpr unsigned int NumThreads = ThreadsM * ThreadsN; const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half); - cudaFuncSetAttribute(conv2d_implicit_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75 - dim3 gridDim(BlocksN, BlocksM); - dim3 blockDim(ThreadsN, ThreadsM); + const unsigned int K2MN = 8; - conv2d_implicit_kernel - <<>>(X_H, K_H, Y_H.get(), P); - const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); - to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow * P.n, st); + if (P.c * P.r * P.s > K2MN * P.n * P.Oh * P.Ow || P.c * P.r * P.s > K2MN * 
P.k) { + const unsigned int ksplit = 8; + ggml_cuda_pool_alloc Y_H(ctx.pool(id), ksplit * P.k * P.Oh * P.Ow * P.n); + + cudaFuncSetAttribute(conv2d_implicit_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75 + dim3 gridDim(BlocksN, BlocksM, ksplit); + dim3 blockDim(ThreadsN, ThreadsM); + + conv2d_implicit_kernel + <<>>(X_H, K_H, Y_H.get(), P); + + const unsigned int nrows = P.n * P.k * P.Oh * P.Ow; + const unsigned int blockx = (nrows + 511) / 512; + const dim3 block_nums(blockx, 1, 1); + const dim3 block_dims(512, 1, 1); + reduce_f32<<>>(Y_H.get(), Y_D, nrows, ksplit); + + } else { + ggml_cuda_pool_alloc Y_H(ctx.pool(id), P.k * P.Oh * P.Ow * P.n); + + cudaFuncSetAttribute(conv2d_implicit_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75 + dim3 gridDim(BlocksN, BlocksM); + dim3 blockDim(ThreadsN, ThreadsM); + + conv2d_implicit_kernel + <<>>(X_H, K_H, Y_H.get(), P); + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); + to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow * P.n, st); + } } else{ conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); } diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index 347ca12b3e..b242277eb0 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -32,6 +32,8 @@ unsigned int NUM_THREADS> __device__ __forceinline__ void tileMemcpySwizzleB( const half* src, half* dst, + const unsigned int start_k, + const unsigned int end_k, const unsigned int src_stride, param_t param ){ @@ -57,9 +59,9 @@ __device__ __forceinline__ void tileMemcpySwizzleB( constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; - const unsigned int curR = fastdiv(thread_col*8, param.SC_fastdiv); // channel offset - const unsigned int curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - const unsigned int curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // + const unsigned int curR = fastdiv(start_k+thread_col*8, param.SC_fastdiv); // channel offset + const unsigned int curS = fastdiv(fastmodulo(start_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const unsigned int curC = fastmodulo(fastmodulo(start_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++){ @@ -68,7 +70,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB( unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); - if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c){ + if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c && start_k+thread_col*8 < end_k){ dst_float4[dst_index] = reinterpret_cast(&src[src_index])[0]; }else{ // read 4 halves dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f); @@ -91,7 +93,8 @@ unsigned int NUM_THREADS> __device__ __forceinline__ void tileMemcpySwizzleA( const half* src, half* dst, - // const unsigned int src_stride, + const unsigned int start_k, + const unsigned int end_k, const unsigned int inChannelOffset, param_t param ) @@ -128,9 +131,9 @@ 
__device__ __forceinline__ void tileMemcpySwizzleA( int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p; int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q; unsigned int inOffset = n * param.c * param.h * param.w; - const unsigned int curR = fastdiv(thread_col*8, param.SC_fastdiv); // channel offset - const unsigned int curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - const unsigned int curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const unsigned int curR = fastdiv(start_k+thread_col*8, param.SC_fastdiv); // channel offset + const unsigned int curS = fastdiv(fastmodulo(start_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const unsigned int curC = fastmodulo(fastmodulo(start_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset int curH = posh_ori + curR * param.d_h; // input h int curW = posw_ori + curS * param.d_w; // input w // apply swizzle to the dst index @@ -138,7 +141,7 @@ __device__ __forceinline__ void tileMemcpySwizzleA( dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && - curR < param.r && curS < param.s && curC < param.c){ + curR < param.r && curS < param.s && curC < param.c && start_k+thread_col*8 < end_k){ const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; dst_float4[dst_index] = reinterpret_cast(&src[inOffset + inOffsetTmp])[0]; } else{ @@ -164,6 +167,8 @@ __device__ __forceinline__ void tileMemcpyLoadA( float4 (&dst_reg)[ELEMENTS_PER_THREAD], // const unsigned int src_stride, const unsigned int block_k, + const unsigned int start_k, + const unsigned int end_k, const unsigned int inChannelOffset, param_t param ){ @@ -194,13 +199,13 @@ __device__ __forceinline__ void tileMemcpyLoadA( int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p; int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q; unsigned int inOffset = n * param.c * param.h * param.w; - const unsigned int curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // channel offset - const unsigned int curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const unsigned int curR = fastdiv(start_k+block_k+thread_col*8, param.SC_fastdiv); // channel offset + const unsigned int curS = fastdiv(fastmodulo(start_k+block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const unsigned int curC = fastmodulo(fastmodulo(start_k+block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset int curH = posh_ori + curR * param.d_h; // input h int curW = posw_ori + curS * param.d_w; // input w if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && - curR < param.r && curS < param.s && curC < param.c){ + curR < param.r && curS < param.s && curC < param.c && start_k+block_k+thread_col*8 < end_k){ const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; dst_reg[i] = reinterpret_cast(&src[inOffset + inOffsetTmp])[0]; } else{ @@ -227,6 +232,8 @@ __device__ __forceinline__ void tileMemcpyLoadB( const half* src, float4 (&dst_reg)[ELEMENTS_PER_THREAD], const unsigned int block_k, 
+ const unsigned int start_k, + const unsigned int end_k, const unsigned int src_stride, param_t param ){ @@ -249,14 +256,14 @@ __device__ __forceinline__ void tileMemcpyLoadB( // compile time check that we provided the right amount of registers for storage static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); - const unsigned int curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // channel offset - const unsigned int curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // + const unsigned int curR = fastdiv(start_k+block_k+thread_col*8, param.SC_fastdiv); // channel offset + const unsigned int curS = fastdiv(fastmodulo(start_k+block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const unsigned int curC = fastmodulo(fastmodulo(start_k+block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++){ const unsigned int src_index = thread_row * src_stride + block_k + thread_col * 8; - if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c){ + if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c && start_k+block_k+thread_col*8 < end_k){ dst_reg[i] = reinterpret_cast(&src[src_index])[0]; }else{ // read 4 halves dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f); diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index 48dcaa47d8..b5e7b18a2a 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -309,7 +309,6 @@ int main(void) std::make_tuple(4,320,64,96,3,3), std::make_tuple(320,4,64,96,3,3), std::make_tuple(640,640,96,128,3,3), - std::make_tuple(320,1280,26,38,3,3), std::make_tuple(1280,1280,26,38,1,1), std::make_tuple(256,128,768,1024,3,3), std::make_tuple(128,3,768,1024,3,3), @@ -385,14 +384,13 @@ int main(void) // for(int i = 0; i < ggml_nelements(wino_res); i++) { // for(int i = 0; i < 26*38; i++) { - // for(int i = 0; i < conv2d_data.size(); i++) { - // // float diff = fabs(conv2d_data[i] - wino_data[i]); - // float diff = fabs(im2col_data[i] - wino_data[i]); - // float diff1 = fabs(im2col_data[i] - conv2d_data[i]); + // // for(int i = 26*38; i < 2*26*38; i++) { + // // for(int i = 0; i < conv2d_data.size(); i++) { + // float diff = fabs(im2col_data[i] - conv2d_data[i]); // // if(diff > 0.5) { - // printf("(%7.3f, %7.3f, %7.3f, %.2f, %.2f, %d) \n", + // printf("(%7.3f, %7.3f, %.2f, %d) \n", // im2col_data[i], conv2d_data[i], - // wino_data[i], diff, diff1, i); + // diff, i); // // break; // // } // } From 688de6d7d8029ce2463d7b16639cab00c1b4e2eb Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 5 Nov 2025 13:47:38 -0500 Subject: [PATCH 059/122] fixed bug now split-k is working --- ggml/src/ggml-cuda/conv2d-implicit.cu | 3 --- ggml/src/ggml-cuda/conv2d-implicit.cuh | 4 ++-- tests/test-conv2d.cpp | 6 +++--- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 216f895922..fe55de4b91 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -819,9 +819,6 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, } } - - - if (block_k != num_block_tiles_k) { // switch smem buffers each iteration diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index b242277eb0..0226e715ce 100644 --- 
a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -66,7 +66,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB( #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++){ // apply swizzle to the dst index - const unsigned int src_index = thread_row * src_stride + thread_col * 8; + const unsigned int src_index = thread_row * src_stride + start_k + thread_col * 8; unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); @@ -262,7 +262,7 @@ __device__ __forceinline__ void tileMemcpyLoadB( #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++){ - const unsigned int src_index = thread_row * src_stride + block_k + thread_col * 8; + const unsigned int src_index = thread_row * src_stride + start_k + block_k + thread_col * 8; if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c && start_k+block_k+thread_col*8 < end_k){ dst_reg[i] = reinterpret_cast(&src[src_index])[0]; }else{ // read 4 halves diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index b5e7b18a2a..87abd015dc 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -44,7 +44,7 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, int kw = 3, // create data int KW = kw, KH = kh, IC = ic, OC = oc; int IW = iw, IH = ih, N = 1; - srand(time(NULL)); + // srand(time(NULL)); // printf(" input: IC = %d, OC = %d, IW = %d, IH = %d \n ", IC, OC, IW, IH); @@ -384,8 +384,8 @@ int main(void) // for(int i = 0; i < ggml_nelements(wino_res); i++) { // for(int i = 0; i < 26*38; i++) { - // // for(int i = 26*38; i < 2*26*38; i++) { - // // for(int i = 0; i < conv2d_data.size(); i++) { + // for(int i = 26*38; i < 2*26*38; i++) { + // for(int i = 0; i < conv2d_data.size(); i++) { // float diff = fabs(im2col_data[i] - conv2d_data[i]); // // if(diff > 0.5) { // printf("(%7.3f, %7.3f, %.2f, %d) \n", From d9a48580fc34de5b1e2a2ce21efb337d753080cc Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 5 Nov 2025 13:58:25 -0500 Subject: [PATCH 060/122] use a better criterian to use split-k --- ggml/src/ggml-cuda/conv2d-implicit.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index fe55de4b91..d2d775c9b2 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -984,9 +984,9 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa constexpr unsigned int NumThreads = ThreadsM * ThreadsN; const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half); - const unsigned int K2MN = 8; + const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; - if (P.c * P.r * P.s > K2MN * P.n * P.Oh * P.Ow || P.c * P.r * P.s > K2MN * P.k) { + if (BlocksM * BlocksN < nsm) { const unsigned int ksplit = 8; ggml_cuda_pool_alloc Y_H(ctx.pool(id), ksplit * P.k * P.Oh * P.Ow * P.n); From 09e3a5f07d72978f920a29b3765fc42fb3880cdf Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 5 Nov 2025 22:02:57 -0500 Subject: [PATCH 061/122] try to reduce index calculation --- ggml/src/ggml-cuda/conv2d-implicit.cu | 26 ++++++++----- ggml/src/ggml-cuda/conv2d-implicit.cuh | 54 +++++++++++++++----------- tests/test-conv2d.cpp | 3 +- 3 files changed, 49 insertions(+), 34 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu 
b/ggml/src/ggml-cuda/conv2d-implicit.cu index d2d775c9b2..9b2331876b 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -720,6 +720,10 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, const uint inChannelOffset = param.c * param.w; const uint weightKOffset = K; + const unsigned int PQ = param.Ow * param.Oh; + const unsigned int KPQ = param.k * PQ; + const unsigned int NKPQ = param.n * KPQ; + // loop bounds, constexpr where possible allows for loop unrolling constexpr unsigned int mma_tiles_per_warp_k = 4; constexpr unsigned int mma_tiles_per_warp_m = WM / MMA_M; @@ -845,14 +849,15 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, for (int i = 0; i < 2; ++i) { __syncthreads(); - +#pragma unroll for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++) { + const int output_sts_offset = output_sts_addr + mma_m * MMA_M * BN / 2 - i * mma_tiles_per_warp_n/2 * MMA_N; for (unsigned int mma_n = i * mma_tiles_per_warp_n/2; mma_n < (i+1)*mma_tiles_per_warp_n/2; mma_n++) { uint32_t (®_)[2] = reinterpret_cast(acc_register_[mma_m][mma_n]); - uint idx = output_sts_addr + - mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N; + uint idx = output_sts_offset + mma_n * MMA_N; + // mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N; idx = idx ^ ((idx & 0b1110000000) >> 4); uint32_t* dst_ptr = reinterpret_cast(&smemoutput[idx]); dst_ptr[0] = reg_[0]; @@ -861,24 +866,25 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, } } __syncthreads(); - + const unsigned int m_i_wn = m_idx + i * WN / 2; #pragma unroll for (int subk = 0; subk < WN / 2; ++subk){ + const uint row = m_i_wn + subk; +#pragma unroll for (int j = 0; j < 4; ++j){ - const uint row = m_idx + subk + i * WN / 2; const uint gemm_i = n_idx + j*32; const int n = fastdiv(gemm_i, param.OHOW_fastdiv); const int col = fastmodulo(gemm_i, param.OHOW_fastdiv); - if (n < param.n && row < param.k && col < param.Oh * param.Ow) { + if (n < param.n && row < param.k && col < PQ) { uint idx = output_lds_addr + subk + j*32*BN/2; idx = idx ^ ((idx & 0b1110000000) >> 4); if constexpr (ksplit > 0) { - const uint outOffset = z * param.n * param.k * param.Oh * param.Ow + - n * param.k * param.Oh * param.Ow + - row * param.Oh * param.Ow + col; + const uint outOffset = z * NKPQ + + n * KPQ + + row * PQ + col; output[outOffset] = smemoutput[idx]; } else { - const uint outOffset = n * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col; + const uint outOffset = n * KPQ + row * PQ + col; output[outOffset] = smemoutput[idx]; } } diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index 0226e715ce..a49210ddd8 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -59,18 +59,20 @@ __device__ __forceinline__ void tileMemcpySwizzleB( constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; - const unsigned int curR = fastdiv(start_k+thread_col*8, param.SC_fastdiv); // channel offset - const unsigned int curS = fastdiv(fastmodulo(start_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - const unsigned int curC = fastmodulo(fastmodulo(start_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // + + const unsigned int ki = start_k+thread_col*8; + const 
unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset + const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++){ // apply swizzle to the dst index - const unsigned int src_index = thread_row * src_stride + start_k + thread_col * 8; + const unsigned int src_index = thread_row * src_stride + ki; unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); - if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c && start_k+thread_col*8 < end_k){ + if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c && ki < end_k){ dst_float4[dst_index] = reinterpret_cast(&src[src_index])[0]; }else{ // read 4 halves dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f); @@ -122,6 +124,12 @@ __device__ __forceinline__ void tileMemcpySwizzleA( unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; + const unsigned int ki = start_k+thread_col*8; + const unsigned int chw = param.c * param.h * param.w; + const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset + const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++){ @@ -130,10 +138,7 @@ __device__ __forceinline__ void tileMemcpySwizzleA( unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv); int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p; int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q; - unsigned int inOffset = n * param.c * param.h * param.w; - const unsigned int curR = fastdiv(start_k+thread_col*8, param.SC_fastdiv); // channel offset - const unsigned int curS = fastdiv(fastmodulo(start_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - const unsigned int curC = fastmodulo(fastmodulo(start_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + // unsigned int inOffset = n * param.c * param.h * param.w; int curH = posh_ori + curR * param.d_h; // input h int curW = posw_ori + curS * param.d_w; // input w // apply swizzle to the dst index @@ -141,9 +146,9 @@ __device__ __forceinline__ void tileMemcpySwizzleA( dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && - curR < param.r && curS < param.s && curC < param.c && start_k+thread_col*8 < end_k){ + curR < param.r && curS < param.s && curC < param.c && ki < end_k){ const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; - dst_float4[dst_index] = reinterpret_cast(&src[inOffset + inOffsetTmp])[0]; + dst_float4[dst_index] = reinterpret_cast(&src[n * chw + inOffsetTmp])[0]; } else{ dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f); } @@ -191,6 +196,13 @@ __device__ __forceinline__ void tileMemcpyLoadA( // compile time check that we provided the right amount of registers for storage 
static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); + const unsigned int ki = start_k+block_k+thread_col*8; + const unsigned int chw = param.c * param.h * param.w; + + const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset + const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++){ unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row; @@ -198,16 +210,13 @@ __device__ __forceinline__ void tileMemcpyLoadA( unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv); int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p; int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q; - unsigned int inOffset = n * param.c * param.h * param.w; - const unsigned int curR = fastdiv(start_k+block_k+thread_col*8, param.SC_fastdiv); // channel offset - const unsigned int curS = fastdiv(fastmodulo(start_k+block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - const unsigned int curC = fastmodulo(fastmodulo(start_k+block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + // unsigned int inOffset = n * param.c * param.h * param.w; int curH = posh_ori + curR * param.d_h; // input h int curW = posw_ori + curS * param.d_w; // input w if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && - curR < param.r && curS < param.s && curC < param.c && start_k+block_k+thread_col*8 < end_k){ + curR < param.r && curS < param.s && curC < param.c && ki < end_k){ const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; - dst_reg[i] = reinterpret_cast(&src[inOffset + inOffsetTmp])[0]; + dst_reg[i] = reinterpret_cast(&src[n * chw + inOffsetTmp])[0]; } else{ dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f); } @@ -256,14 +265,15 @@ __device__ __forceinline__ void tileMemcpyLoadB( // compile time check that we provided the right amount of registers for storage static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); - const unsigned int curR = fastdiv(start_k+block_k+thread_col*8, param.SC_fastdiv); // channel offset - const unsigned int curS = fastdiv(fastmodulo(start_k+block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - const unsigned int curC = fastmodulo(fastmodulo(start_k+block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // + const unsigned int ki = start_k+block_k+thread_col*8; + const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset + const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++){ - const unsigned int src_index = thread_row * src_stride + start_k + block_k + thread_col * 8; - if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c && start_k+block_k+thread_col*8 < end_k){ + const unsigned int src_index = thread_row * src_stride + ki; + if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c && ki < end_k){ dst_reg[i] = reinterpret_cast(&src[src_index])[0]; }else{ // read 4 halves dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f); diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index 87abd015dc..e5da8ab056 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp 
@@ -384,8 +384,7 @@ int main(void) // for(int i = 0; i < ggml_nelements(wino_res); i++) { // for(int i = 0; i < 26*38; i++) { - // for(int i = 26*38; i < 2*26*38; i++) { - // for(int i = 0; i < conv2d_data.size(); i++) { + // // for(int i = 0; i < conv2d_data.size(); i++) { // float diff = fabs(im2col_data[i] - conv2d_data[i]); // // if(diff > 0.5) { // printf("(%7.3f, %7.3f, %.2f, %d) \n", From 68ccd2a899e7c50443e7ea63056b9bbbf73372ca Mon Sep 17 00:00:00 2001 From: bssrdf Date: Thu, 6 Nov 2025 09:54:01 -0500 Subject: [PATCH 062/122] refactor cuda core code path --- ggml/src/ggml-cuda/conv2d-implicit.cu | 214 ++----------------------- ggml/src/ggml-cuda/conv2d-implicit.cuh | 133 +++++++++++++++ 2 files changed, 150 insertions(+), 197 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 9b2331876b..1e6eee88c9 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -27,6 +27,7 @@ static __global__ void reduce_f32(const src_T * __restrict__ x, dst_T * __restri } } + template static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, const int ne00, const int ne01){ @@ -63,6 +64,8 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co } } + + template){ - float4 tmp = reinterpret_cast(&kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4])[0]; - smemweight[weight_sts_addr + offset + 0] = tmp.x; - smemweight[weight_sts_addr + offset + (BN+PAD)] = tmp.y; - smemweight[weight_sts_addr + offset + 2*(BN+PAD)] = tmp.z; - smemweight[weight_sts_addr + offset + 3*(BN+PAD)] = tmp.w; - }else{ // read 4 halves - float2 tmp = reinterpret_cast(&kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4])[0]; - const half *val = reinterpret_cast(&tmp); - smemweight[weight_sts_addr + offset + 0] = val[0]; - smemweight[weight_sts_addr + offset + (BN+PAD)] = val[1]; - smemweight[weight_sts_addr + offset + 2*(BN+PAD)] = val[2]; - smemweight[weight_sts_addr + offset + 3*(BN+PAD)] = val[3]; - } - } else { -#pragma unroll - for (int i = 0; i < 4; ++i){ - smemweight[weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f; - } - } - }else{ -#pragma unroll - for (int i = 0; i < 4; ++i){ - if (by * BN + innerRowA + offset < param.k && start_k + innerColA * 4 + i < end_k){ - smemweight[weight_sts_addr + offset + i*(BN+PAD)] = kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4 + i]; - } else { - smemweight[weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f; - } - } - } - } + loadFilter + (kernel, smemweight, by, innerRowA, innerColA, weightKOffset, + start_k, end_k, param); + loadInput + (input, smeminput, bx, innerRowA, innerColA, + start_k, end_k, PQ, CHW, inChannelOffset, param); - const uint input_sts_addr = innerRowA + innerColA * (BM+PAD) * 4; -#pragma unroll - for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) { - int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQ : z; - const unsigned int npq_res = (bx * BM + innerRowA + offset) % PQ; - const int posh_ori = fastdiv((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.u - param.p; - const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q; - int inOffset = n * param.c * param.h * param.w ; - if(vec_load){ - const uint cur0 = fastdiv(start_k + innerColA * 4, - layout == 0 ? 
param.SC_fastdiv : param.RS_fastdiv); // channel offset - const uint cur1 = fastdiv(fastmodulo(start_k + innerColA * 4, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), - layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - const uint cur2 = fastmodulo(fastmodulo(start_k + innerColA * 4, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), - layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - const uint curC = layout == 0 ? cur2 : cur0; - const uint curR = layout == 0 ? cur0 : cur1; - const uint curS = layout == 0 ? cur1 : cur2; - const int curH = posh_ori + curR * param.d_h; // input h - const int curW = posw_ori + curS * param.d_w; // input w - if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 < end_k){ - int inOffsetTmp = layout == 0 ? - curH * inChannelOffset + curW * param.c + curC: - curC * inChannelOffset + curH * param.w + curW; - float4 tmp = reinterpret_cast(&input[inOffset + inOffsetTmp])[0]; - smeminput[input_sts_addr + offset + 0] = tmp.x; - smeminput[input_sts_addr + offset + BM+PAD] = tmp.y; - smeminput[input_sts_addr + offset + 2*(BM+PAD)] = tmp.z; - smeminput[input_sts_addr + offset + 3*(BM+PAD)] = tmp.w; - } else { -#pragma unroll - for (int i = 0; i < 4; ++i) - smeminput[input_sts_addr + offset + i*(BM+PAD)] = 0.f; - } - } else { -#pragma unroll - for (int i = 0; i < 4; ++i){ - const uint cur0 = fastdiv(start_k + innerColA * 4 + i, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset - const uint cur1 = fastdiv(fastmodulo(start_k + innerColA * 4 + i, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), - layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - const uint cur2 = fastmodulo(fastmodulo(start_k + innerColA * 4 + i, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), - layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - const uint curC = layout == 0 ? cur2 : cur0; - const uint curR = layout == 0 ? cur0 : cur1; - const uint curS = layout == 0 ? cur1 : cur2; - const int curH = posh_ori + curR * param.d_h; // input h - const int curW = posw_ori + curS * param.d_w; // input w - if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 + i < end_k){ - int inOffsetTmp = layout == 0 ? 
- curH * inChannelOffset + curW * param.c + curC: - curC * inChannelOffset + curH * param.w + curW; - smeminput[input_sts_addr + offset + i*(BM+PAD)] = input[inOffset + inOffsetTmp]; - } else { - smeminput[input_sts_addr + offset + i*(BM+PAD)] = 0.f; - } - } - } - } __syncthreads(); // lds @@ -279,106 +190,15 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, } } // ldg -#pragma unroll - for (uint offset = 0; offset + rowStrideA <= BN; offset += rowStrideA) { - if(vec_load){ - if (by * BN + innerRowA + offset < param.k && innerColA * 4 + crs + BK < end_k){ - if constexpr (std::is_same_v){ - float4 tmp = reinterpret_cast(&kernel[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK])[0]; - smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 0] = tmp.x; - smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + (BN+PAD)] = tmp.y; - smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 2*(BN+PAD)] = tmp.z; - smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 3*(BN+PAD)] = tmp.w; - } else { - float2 tmp = reinterpret_cast(&kernel[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK])[0]; - const half *val = reinterpret_cast(&tmp); - smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 0] = val[0]; - smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + (BN+PAD)] = val[1]; - smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 2*(BN+PAD)] = val[2]; - smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 3*(BN+PAD)] = val[3]; - } - } else { -#pragma unroll - for (int i = 0; i < 4; ++i) - smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f; - } - }else{ -#pragma unroll - for (int i = 0; i < 4; ++i){ - if (by * BN + innerRowA + offset < param.k && innerColA * 4 + crs + BK + i < end_k){ - // float4 tmp = reinterpret_cast(¶m.weight[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK + i])[0]; - smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + i*(BN+PAD)] = kernel[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK + i]; - } else { - smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f; - } - } - } - } -#pragma unroll - for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) { - int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQ : z; - const unsigned int npq_res = (bx * BM + innerRowA + offset) % PQ; - const int posh_ori = fastdiv((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.u - param.p; - const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q; - int inOffset = n * param.c * param.h * param.w ; - if(vec_load){ - const uint cur0 = fastdiv(innerColA * 4 + crs + BK, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset - const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), - layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), - layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - const uint curC = layout == 0 ? cur2 : cur0; - const uint curR = layout == 0 ? cur0 : cur1; - const uint curS = layout == 0 ? 
cur1 : cur2; - const int curH = posh_ori + curR * param.d_h; // input h - const int curW = posw_ori + curS * param.d_w; // input w - if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK < end_k){ - // int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; - int inOffsetTmp = layout == 0 ? - curH * inChannelOffset + curW * param.c + curC: - curC * inChannelOffset + curH * param.w + curW; - float4 tmp = reinterpret_cast(&input[inOffset + inOffsetTmp])[0]; - smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + 0] = tmp.x; - smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + BM+PAD] = tmp.y; - smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + 2*(BM+PAD)] = tmp.z; - smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + 3*(BM+PAD)] = tmp.w; - } else { -#pragma unroll - for (int i = 0; i < 4; ++i) - smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + i*(BM+PAD)] = 0.f; - } - } else { -#pragma unroll - for (int i = 0; i < 4; ++i){ - const uint cur0 = fastdiv(innerColA * 4 + crs + BK + i, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset - const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK + i, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), - layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK + i, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), - layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - const uint curC = layout == 0 ? cur2 : cur0; - const uint curR = layout == 0 ? cur0 : cur1; - const uint curS = layout == 0 ? cur1 : cur2; + loadFilter + (kernel, &smemweight[write_flag * (BN+PAD) * BK], by, innerRowA, innerColA, weightKOffset, + crs+BK, end_k, param); + + loadInput + (input, &smeminput[write_flag * (BM+PAD) * BK], bx, innerRowA, innerColA, + crs + BK, end_k, PQ, CHW, inChannelOffset, param); - const int curH = posh_ori + curR * param.d_h; // input h - const int curW = posw_ori + curS * param.d_w; // input w - if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK + i < end_k){ - int inOffsetTmp = layout == 0 ? 
- curH * inChannelOffset + curW * param.c + curC: - curC * inChannelOffset + curH * param.w + curW; - smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + i*(BM+PAD)] = input[inOffset + inOffsetTmp]; - } else { - smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + i*(BM+PAD)] = 0.f; - } - } - } - } __syncthreads(); write_flag ^= 1; diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index a49210ddd8..85936e42c6 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -359,6 +359,139 @@ __device__ __forceinline__ uint32_t cvta_to_shared_u32(const void *pointer) { return address; } +template +__device__ __forceinline__ void loadFilter(const T * __restrict__ kernel, + T * __restrict__ smemweight, + const unsigned int by, + const unsigned int innerRowA, + const unsigned int innerColA, + const unsigned int weightKOffset, + const unsigned int start_k, + const unsigned int end_k, + const param_t param){ + + const unsigned int weight_sts_addr = innerRowA + innerColA * (BN+PAD) * 4; + const unsigned int kidx = start_k + innerColA * 4; +#pragma unroll + for (int offset = 0; offset + rowStrideA <= BN; offset += rowStrideA) { + const unsigned int nidx = by * BN + innerRowA + offset; + if (vec_load) { + if (nidx < param.k && kidx < end_k) { + if constexpr (std::is_same_v){ + float4 tmp = reinterpret_cast(&kernel[nidx * weightKOffset + kidx])[0]; + smemweight[weight_sts_addr + offset + 0] = tmp.x; + smemweight[weight_sts_addr + offset + (BN+PAD)] = tmp.y; + smemweight[weight_sts_addr + offset + 2*(BN+PAD)] = tmp.z; + smemweight[weight_sts_addr + offset + 3*(BN+PAD)] = tmp.w; + } else { // read 4 halves + float2 tmp = reinterpret_cast(&kernel[nidx * weightKOffset + kidx])[0]; + const half *val = reinterpret_cast(&tmp); + smemweight[weight_sts_addr + offset + 0] = val[0]; + smemweight[weight_sts_addr + offset + (BN+PAD)] = val[1]; + smemweight[weight_sts_addr + offset + 2*(BN+PAD)] = val[2]; + smemweight[weight_sts_addr + offset + 3*(BN+PAD)] = val[3]; + } + } else { +#pragma unroll + for (int i = 0; i < 4; ++i) { + smemweight[weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f; + } + } + } else { +#pragma unroll + for (int i = 0; i < 4; ++i) { + if (nidx < param.k && kidx + i < end_k) { + smemweight[weight_sts_addr + offset + i*(BN+PAD)] = kernel[nidx * weightKOffset + kidx + i]; + } else { + smemweight[weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f; + } + } + } + } +} + + +template +__device__ __forceinline__ void loadInput(const float * __restrict__ input, + float * __restrict__ smeminput, + const unsigned int bx, + const unsigned int innerRowA, + const unsigned int innerColA, + const unsigned int start_k, + const unsigned int end_k, + const unsigned int PQ, + const unsigned int CHW, + const unsigned int inChannelOffset, + const param_t param) { + const unsigned int input_sts_addr = innerRowA + innerColA * (BM+PAD) * 4; + const unsigned int kidx = start_k + innerColA * 4; +#pragma unroll + for (unsigned int offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) { + const unsigned int midx = bx * BM + innerRowA + offset; + int n = (ksplit > 0) ? midx / PQ : blockIdx.z; + const unsigned int npq_res = midx % PQ; + const int posh_ori = fastdiv((ksplit > 0) ? npq_res: midx, param.OW_fastdiv) * param.u - param.p; + const int posw_ori = fastmodulo((ksplit > 0) ? 
npq_res: midx, param.OW_fastdiv) * param.v - param.q; + const unsigned int inOffset = n * CHW; + if (vec_load) { + const unsigned int cur0 = fastdiv(kidx, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset + const unsigned int cur1 = fastdiv(fastmodulo(kidx, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + const unsigned int cur2 = fastmodulo(fastmodulo(kidx, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + const unsigned int curC = layout == 0 ? cur2 : cur0; + const unsigned int curR = layout == 0 ? cur0 : cur1; + const unsigned int curS = layout == 0 ? cur1 : cur2; + const int curH = posh_ori + curR * param.d_h; // input h + const int curW = posw_ori + curS * param.d_w; // input w + if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && kidx < end_k) { + int inOffsetTmp = layout == 0 ? + curH * inChannelOffset + curW * param.c + curC: + curC * inChannelOffset + curH * param.w + curW; + float4 tmp = reinterpret_cast(&input[inOffset + inOffsetTmp])[0]; + smeminput[input_sts_addr + offset + 0] = tmp.x; + smeminput[input_sts_addr + offset + BM+PAD] = tmp.y; + smeminput[input_sts_addr + offset + 2*(BM+PAD)] = tmp.z; + smeminput[input_sts_addr + offset + 3*(BM+PAD)] = tmp.w; + } else { +#pragma unroll + for (int i = 0; i < 4; ++i) + smeminput[input_sts_addr + offset + i*(BM+PAD)] = 0.f; + } + } else { +#pragma unroll + for (int i = 0; i < 4; ++i) { + const unsigned int cur0 = fastdiv(kidx + i, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset + const unsigned int cur1 = fastdiv(fastmodulo(kidx + i, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + const unsigned int cur2 = fastmodulo(fastmodulo(kidx + i, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + const unsigned int curC = layout == 0 ? cur2 : cur0; + const unsigned int curR = layout == 0 ? cur0 : cur1; + const unsigned int curS = layout == 0 ? cur1 : cur2; + const int curH = posh_ori + curR * param.d_h; // input h + const int curW = posw_ori + curS * param.d_w; // input w + if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && kidx + i < end_k) { + int inOffsetTmp = layout == 0 ? 
+ curH * inChannelOffset + curW * param.c + curC: + curC * inChannelOffset + curH * param.w + curW; + smeminput[input_sts_addr + offset + i*(BM+PAD)] = input[inOffset + inOffsetTmp]; + } else { + smeminput[input_sts_addr + offset + i*(BM+PAD)] = 0.f; + } + } + } + } +} + #define CUDA_CONV2D_IMPLICT_BLOCK_SIZE 256 void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst); From 311213d209f76bb132a83fe7909196349eaa285e Mon Sep 17 00:00:00 2001 From: bssrdf Date: Thu, 6 Nov 2025 10:21:49 -0500 Subject: [PATCH 063/122] make sure there are enough channels for split-k --- ggml/src/ggml-cuda/conv2d-implicit.cu | 5 ++--- tests/test-conv2d.cpp | 1 + 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 9b2331876b..fc3d25dfc8 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -991,9 +991,8 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half); const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; - - if (BlocksM * BlocksN < nsm) { - const unsigned int ksplit = 8; + const unsigned int ksplit = 8; + if (BlocksM * BlocksN < nsm && P.c > 8 * ksplit) { ggml_cuda_pool_alloc Y_H(ctx.pool(id), ksplit * P.k * P.Oh * P.Ow * P.n); cudaFuncSetAttribute(conv2d_implicit_kernel, diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index e5da8ab056..41807a6b80 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -324,6 +324,7 @@ int main(void) std::make_tuple(256,128,832,1216,3,3), std::make_tuple(256,256,832,1216,3,3), // std::make_tuple(320,256,1024,1920) + // std::make_tuple(32,64,58,58,3,3) }; int k = 0; From ba70ad8e59fc66fc76044ea7f0f3b27b69ceb0c7 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Thu, 6 Nov 2025 20:35:37 -0500 Subject: [PATCH 064/122] added test cases exactly replicating sdxl unet steps --- tests/test-conv2d.cpp | 372 +++++++++++++++++++++++++++++++++++++++--- 1 file changed, 347 insertions(+), 25 deletions(-) diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index 41807a6b80..4bb6c692ba 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -297,34 +297,351 @@ std::vector compute_graph(const test_model & model, ggml_gallocr_t allocr int main(void) { ggml_time_init(); + + double time_iter0 = 0.0, time_iter1 = 0.0; std::vector> configs = { - std::make_tuple(64,64,48,64,3,3), - std::make_tuple(320,320,104,152,3,3), - std::make_tuple(640,640,52,76,3,3), - std::make_tuple(640,640,104,152,3,3), - std::make_tuple(960,320,104,152,3,3), - std::make_tuple(1280,1280,26,38,3,3), - std::make_tuple(4,320,96,128,3,3), - std::make_tuple(320,4,96,128,3,3), - std::make_tuple(4,320,64,96,3,3), - std::make_tuple(320,4,64,96,3,3), - std::make_tuple(640,640,96,128,3,3), - std::make_tuple(1280,1280,26,38,1,1), - std::make_tuple(256,128,768,1024,3,3), - std::make_tuple(128,3,768,1024,3,3), - std::make_tuple(256,128,768,1024,1,1), - std::make_tuple(512,256,384,512,1,1), - std::make_tuple(1280,640,52,76,3,3), - std::make_tuple(1920,1280,26,38,3,3), - std::make_tuple(2560,1280,26,38,3,3), - std::make_tuple(320,1280,26,38,3,3), - std::make_tuple(512,512,104,152,3,3), - std::make_tuple(512,512,208,304,3,3), - std::make_tuple(512,256,416,608,3,3), - std::make_tuple(256,128,832,1216,3,3), - std::make_tuple(256,256,832,1216,3,3), + // std::make_tuple(64,64,48,64,3,3), + // 
std::make_tuple(320,320,104,152,3,3), + // std::make_tuple(640,640,52,76,3,3), + // std::make_tuple(640,640,104,152,3,3), + // std::make_tuple(960,320,104,152,3,3), + // std::make_tuple(1280,1280,26,38,3,3), + // std::make_tuple(4,320,96,128,3,3), + // std::make_tuple(320,4,96,128,3,3), + // std::make_tuple(4,320,64,96,3,3), + // std::make_tuple(320,4,64,96,3,3), + // std::make_tuple(640,640,96,128,3,3), + // std::make_tuple(1280,1280,26,38,1,1), + // std::make_tuple(256,128,768,1024,3,3), + // std::make_tuple(128,3,768,1024,3,3), + // std::make_tuple(256,128,768,1024,1,1), + // std::make_tuple(512,256,384,512,1,1), + // std::make_tuple(1280,640,52,76,3,3), + // std::make_tuple(1920,1280,26,38,3,3), + // std::make_tuple(2560,1280,26,38,3,3), + // std::make_tuple(320,1280,26,38,3,3), + // std::make_tuple(512,512,104,152,3,3), + // std::make_tuple(512,512,208,304,3,3), + // std::make_tuple(512,256,416,608,3,3), + // std::make_tuple(256,128,832,1216,3,3), + // std::make_tuple(256,256,832,1216,3,3), // std::make_tuple(320,256,1024,1920) // std::make_tuple(32,64,58,58,3,3) + + //512x512 + std::make_tuple(4,320,64,64,3,3), + std::make_tuple(320,320,64,64,3,3), + std::make_tuple(320,320,64,64,3,3), + std::make_tuple(320,320,64,64,3,3), + std::make_tuple(320,320,64,64,3,3), + std::make_tuple(320,320,64,64,3,3), + std::make_tuple(320,640,32,32,3,3), + std::make_tuple(640,640,32,32,3,3), + std::make_tuple(320,640,32,32,3,3), + std::make_tuple(640,640,32,32,3,3), + std::make_tuple(640,640,32,32,3,3), + std::make_tuple(640,640,32,32,3,3), + std::make_tuple(640,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(640,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(2560,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(2560,1280,16,16,3,3), + std::make_tuple(2560,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(2560,1280,16,16,3,3), + std::make_tuple(1920,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(1920,1280,16,16,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(1920,640,32,32,3,3), + std::make_tuple(640,640,32,32,3,3), + std::make_tuple(1920,640,32,32,3,3), + std::make_tuple(1280,640,32,32,3,3), + std::make_tuple(640,640,32,32,3,3), + std::make_tuple(1280,640,32,32,3,3), + std::make_tuple(960,640,32,32,3,3), + std::make_tuple(640,640,32,32,3,3), + std::make_tuple(960,640,32,32,3,3), + std::make_tuple(640,640,64,64,3,3), + std::make_tuple(960,320,64,64,3,3), + std::make_tuple(320,320,64,64,3,3), + std::make_tuple(960,320,64,64,3,3), + std::make_tuple(640,320,64,64,3,3), + std::make_tuple(320,320,64,64,3,3), + std::make_tuple(640,320,64,64,3,3), + std::make_tuple(640,320,64,64,3,3), + std::make_tuple(320,320,64,64,3,3), + std::make_tuple(640,320,64,64,3,3), + std::make_tuple(320,4,64,64,3,3), + std::make_tuple(4,320,64,64,3,3), + std::make_tuple(320,320,64,64,3,3), + std::make_tuple(320,320,64,64,3,3), + std::make_tuple(320,320,64,64,3,3), + std::make_tuple(320,320,64,64,3,3), + std::make_tuple(320,320,64,64,3,3), + std::make_tuple(320,640,32,32,3,3), + std::make_tuple(640,640,32,32,3,3), + std::make_tuple(320,640,32,32,3,3), + std::make_tuple(640,640,32,32,3,3), + std::make_tuple(640,640,32,32,3,3), + std::make_tuple(640,640,32,32,3,3), + 
std::make_tuple(640,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(640,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(2560,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(2560,1280,16,16,3,3), + std::make_tuple(2560,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(2560,1280,16,16,3,3), + std::make_tuple(1920,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(1920,1280,16,16,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(1920,640,32,32,3,3), + std::make_tuple(640,640,32,32,3,3), + std::make_tuple(1920,640,32,32,3,3), + std::make_tuple(1280,640,32,32,3,3), + std::make_tuple(640,640,32,32,3,3), + std::make_tuple(1280,640,32,32,3,3), + std::make_tuple(960,640,32,32,3,3), + std::make_tuple(640,640,32,32,3,3), + std::make_tuple(960,640,32,32,3,3), + std::make_tuple(640,640,64,64,3,3), + std::make_tuple(960,320,64,64,3,3), + std::make_tuple(320,320,64,64,3,3), + std::make_tuple(960,320,64,64,3,3), + std::make_tuple(640,320,64,64,3,3), + std::make_tuple(320,320,64,64,3,3), + std::make_tuple(640,320,64,64,3,3), + std::make_tuple(640,320,64,64,3,3), + std::make_tuple(320,320,64,64,3,3), + std::make_tuple(640,320,64,64,3,3), + std::make_tuple(320,4,64,64,3,3), + + //768x768 + // std::make_tuple(4,320,96,96,3,3), + // std::make_tuple(320,320,96,96,3,3), + // std::make_tuple(320,320,96,96,3,3), + // std::make_tuple(320,320,96,96,3,3), + // std::make_tuple(320,320,96,96,3,3), + // std::make_tuple(320,320,96,96,3,3), + // std::make_tuple(320,640,48,48,3,3), + // std::make_tuple(640,640,48,48,3,3), + // std::make_tuple(320,640,48,48,3,3), + // std::make_tuple(640,640,48,48,3,3), + // std::make_tuple(640,640,48,48,3,3), + // std::make_tuple(640,640,48,48,3,3), + // std::make_tuple(640,1280,24,24,3,3), + // std::make_tuple(1280,1280,24,24,3,3), + // std::make_tuple(640,1280,24,24,3,3), + // std::make_tuple(1280,1280,24,24,3,3), + // std::make_tuple(1280,1280,24,24,3,3), + // std::make_tuple(1280,1280,24,24,3,3), + // std::make_tuple(1280,1280,24,24,3,3), + // std::make_tuple(1280,1280,24,24,3,3), + // std::make_tuple(1280,1280,24,24,3,3), + // std::make_tuple(2560,1280,24,24,3,3), + // std::make_tuple(1280,1280,24,24,3,3), + // std::make_tuple(2560,1280,24,24,3,3), + // std::make_tuple(2560,1280,24,24,3,3), + // std::make_tuple(1280,1280,24,24,3,3), + // std::make_tuple(2560,1280,24,24,3,3), + // std::make_tuple(1920,1280,24,24,3,3), + // std::make_tuple(1280,1280,24,24,3,3), + // std::make_tuple(1920,1280,24,24,3,3), + // std::make_tuple(1280,1280,48,48,3,3), + // std::make_tuple(1920,640,48,48,3,3), + // std::make_tuple(640,640,48,48,3,3), + // std::make_tuple(1920,640,48,48,3,3), + // std::make_tuple(1280,640,48,48,3,3), + // std::make_tuple(640,640,48,48,3,3), + // std::make_tuple(1280,640,48,48,3,3), + // std::make_tuple(960,640,48,48,3,3), + // std::make_tuple(640,640,48,48,3,3), + // std::make_tuple(960,640,48,48,3,3), + // std::make_tuple(640,640,96,96,3,3), + // std::make_tuple(960,320,96,96,3,3), + // std::make_tuple(320,320,96,96,3,3), + // std::make_tuple(960,320,96,96,3,3), + // std::make_tuple(640,320,96,96,3,3), + // std::make_tuple(320,320,96,96,3,3), + // std::make_tuple(640,320,96,96,3,3), + // std::make_tuple(640,320,96,96,3,3), + // 
std::make_tuple(320,320,96,96,3,3), + // std::make_tuple(640,320,96,96,3,3), + // std::make_tuple(320,4,96,96,3,3), + // std::make_tuple(4,320,96,96,3,3), + // std::make_tuple(320,320,96,96,3,3), + // std::make_tuple(320,320,96,96,3,3), + // std::make_tuple(320,320,96,96,3,3), + // std::make_tuple(320,320,96,96,3,3), + // std::make_tuple(320,320,96,96,3,3), + // std::make_tuple(320,640,48,48,3,3), + // std::make_tuple(640,640,48,48,3,3), + // std::make_tuple(320,640,48,48,3,3), + // std::make_tuple(640,640,48,48,3,3), + // std::make_tuple(640,640,48,48,3,3), + // std::make_tuple(640,640,48,48,3,3), + // std::make_tuple(640,1280,24,24,3,3), + // std::make_tuple(1280,1280,24,24,3,3), + // std::make_tuple(640,1280,24,24,3,3), + // std::make_tuple(1280,1280,24,24,3,3), + // std::make_tuple(1280,1280,24,24,3,3), + // std::make_tuple(1280,1280,24,24,3,3), + // std::make_tuple(1280,1280,24,24,3,3), + // std::make_tuple(1280,1280,24,24,3,3), + // std::make_tuple(1280,1280,24,24,3,3), + // std::make_tuple(2560,1280,24,24,3,3), + // std::make_tuple(1280,1280,24,24,3,3), + // std::make_tuple(2560,1280,24,24,3,3), + // std::make_tuple(2560,1280,24,24,3,3), + // std::make_tuple(1280,1280,24,24,3,3), + // std::make_tuple(2560,1280,24,24,3,3), + // std::make_tuple(1920,1280,24,24,3,3), + // std::make_tuple(1280,1280,24,24,3,3), + // std::make_tuple(1920,1280,24,24,3,3), + // std::make_tuple(1280,1280,48,48,3,3), + // std::make_tuple(1920,640,48,48,3,3), + // std::make_tuple(640,640,48,48,3,3), + // std::make_tuple(1920,640,48,48,3,3), + // std::make_tuple(1280,640,48,48,3,3), + // std::make_tuple(640,640,48,48,3,3), + // std::make_tuple(1280,640,48,48,3,3), + // std::make_tuple(960,640,48,48,3,3), + // std::make_tuple(640,640,48,48,3,3), + // std::make_tuple(960,640,48,48,3,3), + // std::make_tuple(640,640,96,96,3,3), + // std::make_tuple(960,320,96,96,3,3), + // std::make_tuple(320,320,96,96,3,3), + // std::make_tuple(960,320,96,96,3,3), + // std::make_tuple(640,320,96,96,3,3), + // std::make_tuple(320,320,96,96,3,3), + // std::make_tuple(640,320,96,96,3,3), + // std::make_tuple(640,320,96,96,3,3), + // std::make_tuple(320,320,96,96,3,3), + // std::make_tuple(640,320,96,96,3,3), + // std::make_tuple(320,4,96,96,3,3), + + + //1024x1024 + // std::make_tuple(4,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(320,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(320,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(640,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(640,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(2560,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(2560,1280,32,32,3,3), + // std::make_tuple(2560,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(2560,1280,32,32,3,3), + // std::make_tuple(1920,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(1920,1280,32,32,3,3), + // 
std::make_tuple(1280,1280,64,64,3,3), + // std::make_tuple(1920,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(1920,640,64,64,3,3), + // std::make_tuple(1280,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(1280,640,64,64,3,3), + // std::make_tuple(960,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(960,640,64,64,3,3), + // std::make_tuple(640,640,128,128,3,3), + // std::make_tuple(960,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(960,320,128,128,3,3), + // std::make_tuple(640,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(640,320,128,128,3,3), + // std::make_tuple(640,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(640,320,128,128,3,3), + // std::make_tuple(320,4,128,128,3,3), + // std::make_tuple(4,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(320,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(320,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(640,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(640,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(2560,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(2560,1280,32,32,3,3), + // std::make_tuple(2560,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(2560,1280,32,32,3,3), + // std::make_tuple(1920,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(1920,1280,32,32,3,3), + // std::make_tuple(1280,1280,64,64,3,3), + // std::make_tuple(1920,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(1920,640,64,64,3,3), + // std::make_tuple(1280,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(1280,640,64,64,3,3), + // std::make_tuple(960,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(960,640,64,64,3,3), + // std::make_tuple(640,640,128,128,3,3), + // std::make_tuple(960,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(960,320,128,128,3,3), + // std::make_tuple(640,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(640,320,128,128,3,3), + // std::make_tuple(640,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(640,320,128,128,3,3), + // std::make_tuple(320,4,128,128,3,3), + + }; int k = 0; @@ -377,6 +694,10 @@ int main(void) fprintf(stderr, "| --- | --- | --- | --- | --- \n"); } + time_iter0 += run_time0; + time_iter1 += run_time1; + + fprintf(stderr, " | (%d, %d, %d, %d, %d, %d) | %.2f ms | %.2f MB | %.2f ms | %.2f MB\n", std::get<0>(c), std::get<1>(c), std::get<2>(c), std::get<3>(c), std::get<4>(c), std::get<5>(c), run_time0, mem_size0/1024.0f/1024.0f, @@ -401,6 +722,7 @@ int main(void) ggml_gallocr_free(allocr); } + printf("| 1 unet iter takes| %.2f ms | | %.2f ms | \n", 
time_iter0, time_iter1); // printf("\nPerforming test:\n"); return 0; From 4e9ebe92e0be60cedad2707ed2678d98ee90ef61 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Thu, 6 Nov 2025 22:31:28 -0500 Subject: [PATCH 065/122] minor update --- tests/test-conv2d.cpp | 414 +++++++++++++++++++++--------------------- 1 file changed, 207 insertions(+), 207 deletions(-) diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index 4bb6c692ba..f671c4606a 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -329,108 +329,108 @@ int main(void) // std::make_tuple(32,64,58,58,3,3) //512x512 - std::make_tuple(4,320,64,64,3,3), - std::make_tuple(320,320,64,64,3,3), - std::make_tuple(320,320,64,64,3,3), - std::make_tuple(320,320,64,64,3,3), - std::make_tuple(320,320,64,64,3,3), - std::make_tuple(320,320,64,64,3,3), - std::make_tuple(320,640,32,32,3,3), - std::make_tuple(640,640,32,32,3,3), - std::make_tuple(320,640,32,32,3,3), - std::make_tuple(640,640,32,32,3,3), - std::make_tuple(640,640,32,32,3,3), - std::make_tuple(640,640,32,32,3,3), - std::make_tuple(640,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(640,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(2560,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(2560,1280,16,16,3,3), - std::make_tuple(2560,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(2560,1280,16,16,3,3), - std::make_tuple(1920,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(1920,1280,16,16,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1920,640,32,32,3,3), - std::make_tuple(640,640,32,32,3,3), - std::make_tuple(1920,640,32,32,3,3), - std::make_tuple(1280,640,32,32,3,3), - std::make_tuple(640,640,32,32,3,3), - std::make_tuple(1280,640,32,32,3,3), - std::make_tuple(960,640,32,32,3,3), - std::make_tuple(640,640,32,32,3,3), - std::make_tuple(960,640,32,32,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(960,320,64,64,3,3), - std::make_tuple(320,320,64,64,3,3), - std::make_tuple(960,320,64,64,3,3), - std::make_tuple(640,320,64,64,3,3), - std::make_tuple(320,320,64,64,3,3), - std::make_tuple(640,320,64,64,3,3), - std::make_tuple(640,320,64,64,3,3), - std::make_tuple(320,320,64,64,3,3), - std::make_tuple(640,320,64,64,3,3), - std::make_tuple(320,4,64,64,3,3), - std::make_tuple(4,320,64,64,3,3), - std::make_tuple(320,320,64,64,3,3), - std::make_tuple(320,320,64,64,3,3), - std::make_tuple(320,320,64,64,3,3), - std::make_tuple(320,320,64,64,3,3), - std::make_tuple(320,320,64,64,3,3), - std::make_tuple(320,640,32,32,3,3), - std::make_tuple(640,640,32,32,3,3), - std::make_tuple(320,640,32,32,3,3), - std::make_tuple(640,640,32,32,3,3), - std::make_tuple(640,640,32,32,3,3), - std::make_tuple(640,640,32,32,3,3), - std::make_tuple(640,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(640,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(2560,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(2560,1280,16,16,3,3), - std::make_tuple(2560,1280,16,16,3,3), - 
std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(2560,1280,16,16,3,3), - std::make_tuple(1920,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(1920,1280,16,16,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1920,640,32,32,3,3), - std::make_tuple(640,640,32,32,3,3), - std::make_tuple(1920,640,32,32,3,3), - std::make_tuple(1280,640,32,32,3,3), - std::make_tuple(640,640,32,32,3,3), - std::make_tuple(1280,640,32,32,3,3), - std::make_tuple(960,640,32,32,3,3), - std::make_tuple(640,640,32,32,3,3), - std::make_tuple(960,640,32,32,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(960,320,64,64,3,3), - std::make_tuple(320,320,64,64,3,3), - std::make_tuple(960,320,64,64,3,3), - std::make_tuple(640,320,64,64,3,3), - std::make_tuple(320,320,64,64,3,3), - std::make_tuple(640,320,64,64,3,3), - std::make_tuple(640,320,64,64,3,3), - std::make_tuple(320,320,64,64,3,3), - std::make_tuple(640,320,64,64,3,3), - std::make_tuple(320,4,64,64,3,3), + // std::make_tuple(4,320,64,64,3,3), + // std::make_tuple(320,320,64,64,3,3), + // std::make_tuple(320,320,64,64,3,3), + // std::make_tuple(320,320,64,64,3,3), + // std::make_tuple(320,320,64,64,3,3), + // std::make_tuple(320,320,64,64,3,3), + // std::make_tuple(320,640,32,32,3,3), + // std::make_tuple(640,640,32,32,3,3), + // std::make_tuple(320,640,32,32,3,3), + // std::make_tuple(640,640,32,32,3,3), + // std::make_tuple(640,640,32,32,3,3), + // std::make_tuple(640,640,32,32,3,3), + // std::make_tuple(640,1280,16,16,3,3), + // std::make_tuple(1280,1280,16,16,3,3), + // std::make_tuple(640,1280,16,16,3,3), + // std::make_tuple(1280,1280,16,16,3,3), + // std::make_tuple(1280,1280,16,16,3,3), + // std::make_tuple(1280,1280,16,16,3,3), + // std::make_tuple(1280,1280,16,16,3,3), + // std::make_tuple(1280,1280,16,16,3,3), + // std::make_tuple(1280,1280,16,16,3,3), + // std::make_tuple(2560,1280,16,16,3,3), + // std::make_tuple(1280,1280,16,16,3,3), + // std::make_tuple(2560,1280,16,16,3,3), + // std::make_tuple(2560,1280,16,16,3,3), + // std::make_tuple(1280,1280,16,16,3,3), + // std::make_tuple(2560,1280,16,16,3,3), + // std::make_tuple(1920,1280,16,16,3,3), + // std::make_tuple(1280,1280,16,16,3,3), + // std::make_tuple(1920,1280,16,16,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(1920,640,32,32,3,3), + // std::make_tuple(640,640,32,32,3,3), + // std::make_tuple(1920,640,32,32,3,3), + // std::make_tuple(1280,640,32,32,3,3), + // std::make_tuple(640,640,32,32,3,3), + // std::make_tuple(1280,640,32,32,3,3), + // std::make_tuple(960,640,32,32,3,3), + // std::make_tuple(640,640,32,32,3,3), + // std::make_tuple(960,640,32,32,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(960,320,64,64,3,3), + // std::make_tuple(320,320,64,64,3,3), + // std::make_tuple(960,320,64,64,3,3), + // std::make_tuple(640,320,64,64,3,3), + // std::make_tuple(320,320,64,64,3,3), + // std::make_tuple(640,320,64,64,3,3), + // std::make_tuple(640,320,64,64,3,3), + // std::make_tuple(320,320,64,64,3,3), + // std::make_tuple(640,320,64,64,3,3), + // std::make_tuple(320,4,64,64,3,3), + // std::make_tuple(4,320,64,64,3,3), + // std::make_tuple(320,320,64,64,3,3), + // std::make_tuple(320,320,64,64,3,3), + // std::make_tuple(320,320,64,64,3,3), + // std::make_tuple(320,320,64,64,3,3), + // std::make_tuple(320,320,64,64,3,3), + // std::make_tuple(320,640,32,32,3,3), + // std::make_tuple(640,640,32,32,3,3), + // std::make_tuple(320,640,32,32,3,3), + // std::make_tuple(640,640,32,32,3,3), + // 
std::make_tuple(640,640,32,32,3,3), + // std::make_tuple(640,640,32,32,3,3), + // std::make_tuple(640,1280,16,16,3,3), + // std::make_tuple(1280,1280,16,16,3,3), + // std::make_tuple(640,1280,16,16,3,3), + // std::make_tuple(1280,1280,16,16,3,3), + // std::make_tuple(1280,1280,16,16,3,3), + // std::make_tuple(1280,1280,16,16,3,3), + // std::make_tuple(1280,1280,16,16,3,3), + // std::make_tuple(1280,1280,16,16,3,3), + // std::make_tuple(1280,1280,16,16,3,3), + // std::make_tuple(2560,1280,16,16,3,3), + // std::make_tuple(1280,1280,16,16,3,3), + // std::make_tuple(2560,1280,16,16,3,3), + // std::make_tuple(2560,1280,16,16,3,3), + // std::make_tuple(1280,1280,16,16,3,3), + // std::make_tuple(2560,1280,16,16,3,3), + // std::make_tuple(1920,1280,16,16,3,3), + // std::make_tuple(1280,1280,16,16,3,3), + // std::make_tuple(1920,1280,16,16,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(1920,640,32,32,3,3), + // std::make_tuple(640,640,32,32,3,3), + // std::make_tuple(1920,640,32,32,3,3), + // std::make_tuple(1280,640,32,32,3,3), + // std::make_tuple(640,640,32,32,3,3), + // std::make_tuple(1280,640,32,32,3,3), + // std::make_tuple(960,640,32,32,3,3), + // std::make_tuple(640,640,32,32,3,3), + // std::make_tuple(960,640,32,32,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(960,320,64,64,3,3), + // std::make_tuple(320,320,64,64,3,3), + // std::make_tuple(960,320,64,64,3,3), + // std::make_tuple(640,320,64,64,3,3), + // std::make_tuple(320,320,64,64,3,3), + // std::make_tuple(640,320,64,64,3,3), + // std::make_tuple(640,320,64,64,3,3), + // std::make_tuple(320,320,64,64,3,3), + // std::make_tuple(640,320,64,64,3,3), + // std::make_tuple(320,4,64,64,3,3), //768x768 // std::make_tuple(4,320,96,96,3,3), @@ -538,108 +538,108 @@ int main(void) //1024x1024 - // std::make_tuple(4,320,128,128,3,3), - // std::make_tuple(320,320,128,128,3,3), - // std::make_tuple(320,320,128,128,3,3), - // std::make_tuple(320,320,128,128,3,3), - // std::make_tuple(320,320,128,128,3,3), - // std::make_tuple(320,320,128,128,3,3), - // std::make_tuple(320,640,64,64,3,3), - // std::make_tuple(640,640,64,64,3,3), - // std::make_tuple(320,640,64,64,3,3), - // std::make_tuple(640,640,64,64,3,3), - // std::make_tuple(640,640,64,64,3,3), - // std::make_tuple(640,640,64,64,3,3), - // std::make_tuple(640,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(640,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(2560,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(2560,1280,32,32,3,3), - // std::make_tuple(2560,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(2560,1280,32,32,3,3), - // std::make_tuple(1920,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(1920,1280,32,32,3,3), - // std::make_tuple(1280,1280,64,64,3,3), - // std::make_tuple(1920,640,64,64,3,3), - // std::make_tuple(640,640,64,64,3,3), - // std::make_tuple(1920,640,64,64,3,3), - // std::make_tuple(1280,640,64,64,3,3), - // std::make_tuple(640,640,64,64,3,3), - // std::make_tuple(1280,640,64,64,3,3), - // std::make_tuple(960,640,64,64,3,3), - // std::make_tuple(640,640,64,64,3,3), - // std::make_tuple(960,640,64,64,3,3), - // std::make_tuple(640,640,128,128,3,3), - // 
std::make_tuple(960,320,128,128,3,3), - // std::make_tuple(320,320,128,128,3,3), - // std::make_tuple(960,320,128,128,3,3), - // std::make_tuple(640,320,128,128,3,3), - // std::make_tuple(320,320,128,128,3,3), - // std::make_tuple(640,320,128,128,3,3), - // std::make_tuple(640,320,128,128,3,3), - // std::make_tuple(320,320,128,128,3,3), - // std::make_tuple(640,320,128,128,3,3), - // std::make_tuple(320,4,128,128,3,3), - // std::make_tuple(4,320,128,128,3,3), - // std::make_tuple(320,320,128,128,3,3), - // std::make_tuple(320,320,128,128,3,3), - // std::make_tuple(320,320,128,128,3,3), - // std::make_tuple(320,320,128,128,3,3), - // std::make_tuple(320,320,128,128,3,3), - // std::make_tuple(320,640,64,64,3,3), - // std::make_tuple(640,640,64,64,3,3), - // std::make_tuple(320,640,64,64,3,3), - // std::make_tuple(640,640,64,64,3,3), - // std::make_tuple(640,640,64,64,3,3), - // std::make_tuple(640,640,64,64,3,3), - // std::make_tuple(640,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(640,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(2560,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(2560,1280,32,32,3,3), - // std::make_tuple(2560,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(2560,1280,32,32,3,3), - // std::make_tuple(1920,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(1920,1280,32,32,3,3), - // std::make_tuple(1280,1280,64,64,3,3), - // std::make_tuple(1920,640,64,64,3,3), - // std::make_tuple(640,640,64,64,3,3), - // std::make_tuple(1920,640,64,64,3,3), - // std::make_tuple(1280,640,64,64,3,3), - // std::make_tuple(640,640,64,64,3,3), - // std::make_tuple(1280,640,64,64,3,3), - // std::make_tuple(960,640,64,64,3,3), - // std::make_tuple(640,640,64,64,3,3), - // std::make_tuple(960,640,64,64,3,3), - // std::make_tuple(640,640,128,128,3,3), - // std::make_tuple(960,320,128,128,3,3), - // std::make_tuple(320,320,128,128,3,3), - // std::make_tuple(960,320,128,128,3,3), - // std::make_tuple(640,320,128,128,3,3), - // std::make_tuple(320,320,128,128,3,3), - // std::make_tuple(640,320,128,128,3,3), - // std::make_tuple(640,320,128,128,3,3), - // std::make_tuple(320,320,128,128,3,3), - // std::make_tuple(640,320,128,128,3,3), - // std::make_tuple(320,4,128,128,3,3), + std::make_tuple(4,320,128,128,3,3), + std::make_tuple(320,320,128,128,3,3), + std::make_tuple(320,320,128,128,3,3), + std::make_tuple(320,320,128,128,3,3), + std::make_tuple(320,320,128,128,3,3), + std::make_tuple(320,320,128,128,3,3), + std::make_tuple(320,640,64,64,3,3), + std::make_tuple(640,640,64,64,3,3), + std::make_tuple(320,640,64,64,3,3), + std::make_tuple(640,640,64,64,3,3), + std::make_tuple(640,640,64,64,3,3), + std::make_tuple(640,640,64,64,3,3), + std::make_tuple(640,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(640,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(2560,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(2560,1280,32,32,3,3), + 
std::make_tuple(2560,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(2560,1280,32,32,3,3), + std::make_tuple(1920,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(1920,1280,32,32,3,3), + std::make_tuple(1280,1280,64,64,3,3), + std::make_tuple(1920,640,64,64,3,3), + std::make_tuple(640,640,64,64,3,3), + std::make_tuple(1920,640,64,64,3,3), + std::make_tuple(1280,640,64,64,3,3), + std::make_tuple(640,640,64,64,3,3), + std::make_tuple(1280,640,64,64,3,3), + std::make_tuple(960,640,64,64,3,3), + std::make_tuple(640,640,64,64,3,3), + std::make_tuple(960,640,64,64,3,3), + std::make_tuple(640,640,128,128,3,3), + std::make_tuple(960,320,128,128,3,3), + std::make_tuple(320,320,128,128,3,3), + std::make_tuple(960,320,128,128,3,3), + std::make_tuple(640,320,128,128,3,3), + std::make_tuple(320,320,128,128,3,3), + std::make_tuple(640,320,128,128,3,3), + std::make_tuple(640,320,128,128,3,3), + std::make_tuple(320,320,128,128,3,3), + std::make_tuple(640,320,128,128,3,3), + std::make_tuple(320,4,128,128,3,3), + std::make_tuple(4,320,128,128,3,3), + std::make_tuple(320,320,128,128,3,3), + std::make_tuple(320,320,128,128,3,3), + std::make_tuple(320,320,128,128,3,3), + std::make_tuple(320,320,128,128,3,3), + std::make_tuple(320,320,128,128,3,3), + std::make_tuple(320,640,64,64,3,3), + std::make_tuple(640,640,64,64,3,3), + std::make_tuple(320,640,64,64,3,3), + std::make_tuple(640,640,64,64,3,3), + std::make_tuple(640,640,64,64,3,3), + std::make_tuple(640,640,64,64,3,3), + std::make_tuple(640,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(640,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(2560,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(2560,1280,32,32,3,3), + std::make_tuple(2560,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(2560,1280,32,32,3,3), + std::make_tuple(1920,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(1920,1280,32,32,3,3), + std::make_tuple(1280,1280,64,64,3,3), + std::make_tuple(1920,640,64,64,3,3), + std::make_tuple(640,640,64,64,3,3), + std::make_tuple(1920,640,64,64,3,3), + std::make_tuple(1280,640,64,64,3,3), + std::make_tuple(640,640,64,64,3,3), + std::make_tuple(1280,640,64,64,3,3), + std::make_tuple(960,640,64,64,3,3), + std::make_tuple(640,640,64,64,3,3), + std::make_tuple(960,640,64,64,3,3), + std::make_tuple(640,640,128,128,3,3), + std::make_tuple(960,320,128,128,3,3), + std::make_tuple(320,320,128,128,3,3), + std::make_tuple(960,320,128,128,3,3), + std::make_tuple(640,320,128,128,3,3), + std::make_tuple(320,320,128,128,3,3), + std::make_tuple(640,320,128,128,3,3), + std::make_tuple(640,320,128,128,3,3), + std::make_tuple(320,320,128,128,3,3), + std::make_tuple(640,320,128,128,3,3), + std::make_tuple(320,4,128,128,3,3), }; @@ -690,15 +690,15 @@ int main(void) if(k==0) { k = 1; - fprintf(stderr, "| (IC, OC, IW, IH, KW, KH) | im2col+GEMM TIME | im2col+GEMM VRAM | implicit GEMM TIME | implicit GEMM VRAM \n"); - fprintf(stderr, "| --- | --- | --- | --- | --- \n"); + fprintf(stdout, "| (IC, OC, IW, IH, KW, KH) | im2col+GEMM TIME | im2col+GEMM VRAM | implicit GEMM TIME | implicit GEMM VRAM \n"); + fprintf(stdout, "| --- | --- | --- | --- | --- \n"); } time_iter0 += run_time0; time_iter1 += run_time1; 
- fprintf(stderr, " | (%d, %d, %d, %d, %d, %d) | %.2f ms | %.2f MB | %.2f ms | %.2f MB\n", + fprintf(stdout, " | (%d, %d, %d, %d, %d, %d) | %.2f ms | %.2f MB | %.2f ms | %.2f MB\n", std::get<0>(c), std::get<1>(c), std::get<2>(c), std::get<3>(c), std::get<4>(c), std::get<5>(c), run_time0, mem_size0/1024.0f/1024.0f, run_time1, mem_size1/1024.0f/1024.0f); From df88b2c91746ad198c556e942b8300c4806eddbd Mon Sep 17 00:00:00 2001 From: bssrdf Date: Fri, 7 Nov 2025 15:38:36 -0500 Subject: [PATCH 066/122] trying to get rid of remaining bank conflicts; also fixed a bug for split-k condition check --- ggml/src/ggml-cuda/conv2d-implicit.cu | 38 +++-- tests/test-backend-ops.cpp | 8 + tests/test-conv2d.cpp | 228 +++++++++++++------------- 3 files changed, 150 insertions(+), 124 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 64a521f616..5307e58ed7 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -672,12 +672,11 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, #pragma unroll for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++) { - const int output_sts_offset = output_sts_addr + mma_m * MMA_M * BN / 2 - i * mma_tiles_per_warp_n/2 * MMA_N; for (unsigned int mma_n = i * mma_tiles_per_warp_n/2; mma_n < (i+1)*mma_tiles_per_warp_n/2; mma_n++) { uint32_t (®_)[2] = reinterpret_cast(acc_register_[mma_m][mma_n]); - uint idx = output_sts_offset + mma_n * MMA_N; - // mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N; + uint idx = output_sts_addr + + mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N; idx = idx ^ ((idx & 0b1110000000) >> 4); uint32_t* dst_ptr = reinterpret_cast(&smemoutput[idx]); dst_ptr[0] = reg_[0]; @@ -688,24 +687,40 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, __syncthreads(); const unsigned int m_i_wn = m_idx + i * WN / 2; #pragma unroll - for (int subk = 0; subk < WN / 2; ++subk){ - const uint row = m_i_wn + subk; + for (int subk = 0; subk < WN / 4; ++subk){ + const uint row = m_i_wn + subk*2; #pragma unroll for (int j = 0; j < 4; ++j){ const uint gemm_i = n_idx + j*32; const int n = fastdiv(gemm_i, param.OHOW_fastdiv); const int col = fastmodulo(gemm_i, param.OHOW_fastdiv); + uint idx = output_lds_addr + subk*2 + j*32*BN/2; + idx = idx ^ ((idx & 0b1110000000) >> 4); + uint32_t* dst_ptr = reinterpret_cast(&smemoutput[idx]); if (n < param.n && row < param.k && col < PQ) { - uint idx = output_lds_addr + subk + j*32*BN/2; - idx = idx ^ ((idx & 0b1110000000) >> 4); if constexpr (ksplit > 0) { const uint outOffset = z * NKPQ + n * KPQ + row * PQ + col; - output[outOffset] = smemoutput[idx]; + // output[outOffset] = smemoutput[idx]; + output[outOffset] = reinterpret_cast(dst_ptr)[0]; } else { const uint outOffset = n * KPQ + row * PQ + col; - output[outOffset] = smemoutput[idx]; + // output[outOffset] = smemoutput[idx]; + output[outOffset] = reinterpret_cast(dst_ptr)[0]; + } + } + if (n < param.n && row+1 < param.k && col < PQ) { + if constexpr (ksplit > 0) { + const uint outOffset = z * NKPQ + + n * KPQ + + (row+1) * PQ + col; + // output[outOffset] = smemoutput[idx]; + output[outOffset] = reinterpret_cast(dst_ptr)[1]; + } else { + const uint outOffset = n * KPQ + (row+1) * PQ + col; + // output[outOffset] = smemoutput[idx]; + output[outOffset] = reinterpret_cast(dst_ptr)[1]; } } } @@ -803,6 +818,9 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, 
const floa constexpr unsigned int WM_dim = BM_dim / WARPS_PER_BLOCK_M; constexpr unsigned int WN_dim = BN_dim / WARPS_PER_BLOCK_N; constexpr unsigned int WK_dim = BK_dim / WARPS_PER_BLOCK_K; + + static_assert(WN_dim % 4 == 0, "final output requires this to be bank conflicts free"); + const unsigned int BlocksM = (P.n * P.Oh * P.Ow + BM_dim - 1) / BM_dim; const unsigned int BlocksN = (P.k + BN_dim - 1) / BN_dim; constexpr unsigned int ThreadsM = WARPS_PER_BLOCK_M; @@ -812,7 +830,7 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; const unsigned int ksplit = 8; - if (BlocksM * BlocksN < nsm && P.c > 8 * ksplit) { + if (BlocksM * BlocksN < nsm && P.c >= 8 * ksplit && (P.c * P.r * P.s) % (8*ksplit) == 0) { ggml_cuda_pool_alloc Y_H(ctx.pool(id), ksplit * P.k * P.Oh * P.Ow * P.n); cudaFuncSetAttribute(conv2d_implicit_kernel, diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 177288c811..16861c71c9 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -5848,6 +5848,14 @@ static std::vector> make_test_cases_eval() { } } + test_cases.emplace_back(new test_conv_2d( { 24, 24, 32, 1 }, { 3, 3, 32, 8}, + GGML_TYPE_F16, 1, 1, 1, 1, 1, 1, false)); + test_cases.emplace_back(new test_conv_2d( { 24, 24, 96, 1 }, { 3, 3, 96, 8}, + GGML_TYPE_F16, 1, 1, 1, 1, 1, 1, false)); + test_cases.emplace_back(new test_conv_2d( { 24, 24, 128, 1 }, { 3, 3, 128, 8}, + GGML_TYPE_F16, 1, 1, 1, 1, 1, 1, false)); + + // sycl backend will limit task global_range < MAX_INT // test cases for 2D im2col with large input W and H (occurs in stable-diffusion) // however these cases need to alloc more memory which may fail in some devices (Intel Arc770, etc.) 
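[editor's note] The split-k gate tightened in this patch now requires, in addition to the output tile grid underfilling the SMs, that (P.c * P.r * P.s) be an exact multiple of 8*ksplit — presumably so each split covers a whole number of the kernel's 8-element K-steps rather than reading past its chunk. The sketch below is a minimal, standalone illustration of that decision (illustrative struct and names only; the real check lives inside conv2d_implicit_cuda_f16 above), exercised with the 96- and 128-channel shapes just added to test-backend-ops.cpp:

    // Sketch of the split-k gating condition; not the actual ggml/CUDA types.
    #include <cstdio>

    struct ConvShape { int c, r, s; };   // input channels, kernel height, kernel width

    static bool can_use_split_k(const ConvShape & p,
                                unsigned blocks_m, unsigned blocks_n,
                                int nsm, unsigned ksplit) {
        // Split-k is only attractive when the output tile grid alone cannot
        // fill the GPU (fewer blocks than streaming multiprocessors).
        if (blocks_m * blocks_n >= (unsigned) nsm) {
            return false;
        }
        // The added modulo check: every one of the ksplit slices of the
        // reduction dimension (C*R*S) must be a whole multiple of 8 elements.
        return p.c >= 8 * (int) ksplit &&
               (p.c * p.r * p.s) % (8 * ksplit) == 0;
    }

    int main() {
        ConvShape c96 {  96, 3, 3 };   //  96*9 =  864, not a multiple of 64
        ConvShape c128{ 128, 3, 3 };   // 128*9 = 1152, a multiple of 64
        printf("c=96 : %d\n", can_use_split_k(c96,  2, 1, 84, 8));  // 0 -> plain kernel
        printf("c=128: %d\n", can_use_split_k(c128, 2, 1, 84, 8));  // 1 -> split-k path
        return 0;
    }

The grid/SM numbers passed to can_use_split_k here are made up for illustration; the point is only that the 96-channel case falls back to the non-split kernel while the 128-channel case qualifies, which is what the new F16 test cases appear to probe.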
diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index f671c4606a..0b1b5c476f 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -305,7 +305,7 @@ int main(void) // std::make_tuple(640,640,52,76,3,3), // std::make_tuple(640,640,104,152,3,3), // std::make_tuple(960,320,104,152,3,3), - // std::make_tuple(1280,1280,26,38,3,3), + std::make_tuple(1280,1280,26,38,3,3), // std::make_tuple(4,320,96,128,3,3), // std::make_tuple(320,4,96,128,3,3), // std::make_tuple(4,320,64,96,3,3), @@ -538,108 +538,108 @@ int main(void) //1024x1024 - std::make_tuple(4,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(320,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(320,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(640,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(640,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(2560,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(2560,1280,32,32,3,3), - std::make_tuple(2560,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(2560,1280,32,32,3,3), - std::make_tuple(1920,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1920,1280,32,32,3,3), - std::make_tuple(1280,1280,64,64,3,3), - std::make_tuple(1920,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(1920,640,64,64,3,3), - std::make_tuple(1280,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(1280,640,64,64,3,3), - std::make_tuple(960,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(960,640,64,64,3,3), - std::make_tuple(640,640,128,128,3,3), - std::make_tuple(960,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(960,320,128,128,3,3), - std::make_tuple(640,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(640,320,128,128,3,3), - std::make_tuple(640,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(640,320,128,128,3,3), - std::make_tuple(320,4,128,128,3,3), - std::make_tuple(4,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(320,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(320,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(640,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(640,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(2560,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(2560,1280,32,32,3,3), - std::make_tuple(2560,1280,32,32,3,3), - 
std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(2560,1280,32,32,3,3), - std::make_tuple(1920,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1920,1280,32,32,3,3), - std::make_tuple(1280,1280,64,64,3,3), - std::make_tuple(1920,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(1920,640,64,64,3,3), - std::make_tuple(1280,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(1280,640,64,64,3,3), - std::make_tuple(960,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(960,640,64,64,3,3), - std::make_tuple(640,640,128,128,3,3), - std::make_tuple(960,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(960,320,128,128,3,3), - std::make_tuple(640,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(640,320,128,128,3,3), - std::make_tuple(640,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(640,320,128,128,3,3), - std::make_tuple(320,4,128,128,3,3), + // std::make_tuple(4,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(320,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(320,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(640,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(640,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(2560,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(2560,1280,32,32,3,3), + // std::make_tuple(2560,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(2560,1280,32,32,3,3), + // std::make_tuple(1920,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(1920,1280,32,32,3,3), + // std::make_tuple(1280,1280,64,64,3,3), + // std::make_tuple(1920,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(1920,640,64,64,3,3), + // std::make_tuple(1280,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(1280,640,64,64,3,3), + // std::make_tuple(960,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(960,640,64,64,3,3), + // std::make_tuple(640,640,128,128,3,3), + // std::make_tuple(960,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(960,320,128,128,3,3), + // std::make_tuple(640,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(640,320,128,128,3,3), + // std::make_tuple(640,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(640,320,128,128,3,3), + // std::make_tuple(320,4,128,128,3,3), + // std::make_tuple(4,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(320,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // 
std::make_tuple(320,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(640,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(640,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(2560,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(2560,1280,32,32,3,3), + // std::make_tuple(2560,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(2560,1280,32,32,3,3), + // std::make_tuple(1920,1280,32,32,3,3), + // std::make_tuple(1280,1280,32,32,3,3), + // std::make_tuple(1920,1280,32,32,3,3), + // std::make_tuple(1280,1280,64,64,3,3), + // std::make_tuple(1920,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(1920,640,64,64,3,3), + // std::make_tuple(1280,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(1280,640,64,64,3,3), + // std::make_tuple(960,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(960,640,64,64,3,3), + // std::make_tuple(640,640,128,128,3,3), + // std::make_tuple(960,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(960,320,128,128,3,3), + // std::make_tuple(640,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(640,320,128,128,3,3), + // std::make_tuple(640,320,128,128,3,3), + // std::make_tuple(320,320,128,128,3,3), + // std::make_tuple(640,320,128,128,3,3), + // std::make_tuple(320,4,128,128,3,3), }; @@ -663,7 +663,7 @@ int main(void) // fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f); - int iterations = 20; + int iterations = 0; double run_time0; std::vector im2col_data = compute_graph(model, allocr, build_graph_0, iterations, &run_time0); @@ -705,16 +705,16 @@ int main(void) // for(int i = 0; i < ggml_nelements(wino_res); i++) { - // for(int i = 0; i < 26*38; i++) { - // // for(int i = 0; i < conv2d_data.size(); i++) { - // float diff = fabs(im2col_data[i] - conv2d_data[i]); - // // if(diff > 0.5) { - // printf("(%7.3f, %7.3f, %.2f, %d) \n", - // im2col_data[i], conv2d_data[i], - // diff, i); - // // break; - // // } - // } + for(int i = 0; i < 26*38; i++) { + // for(int i = 0; i < conv2d_data.size(); i++) { + float diff = fabs(im2col_data[i] - conv2d_data[i]); + // if(diff > 0.5) { + printf("(%7.3f, %7.3f, %.2f, %d) \n", + im2col_data[i], conv2d_data[i], + diff, i); + // break; + // } + } ggml_free(model.ctx); ggml_backend_buffer_free(model.buffer); From 76885c769703478cbba8a2e1f2809c144181fae5 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Fri, 7 Nov 2025 17:44:00 -0500 Subject: [PATCH 067/122] WIP: debugging --- ggml/src/ggml-cuda/conv2d-implicit.cu | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 5307e58ed7..b02224fc06 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -677,6 +677,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, uint32_t (®_)[2] = reinterpret_cast(acc_register_[mma_m][mma_n]); uint idx = output_sts_addr + mma_m * MMA_M * BN / 2 + (mma_n - i 
* mma_tiles_per_warp_n/2) * MMA_N; + idx = idx ^ ((idx & 0b110000000000) >> 9); idx = idx ^ ((idx & 0b1110000000) >> 4); uint32_t* dst_ptr = reinterpret_cast(&smemoutput[idx]); dst_ptr[0] = reg_[0]; @@ -695,19 +696,24 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, const int n = fastdiv(gemm_i, param.OHOW_fastdiv); const int col = fastmodulo(gemm_i, param.OHOW_fastdiv); uint idx = output_lds_addr + subk*2 + j*32*BN/2; + idx = idx ^ ((idx & 0b110000000000) >> 9); idx = idx ^ ((idx & 0b1110000000) >> 4); - uint32_t* dst_ptr = reinterpret_cast(&smemoutput[idx]); + // uint32_t* dst_ptr = reinterpret_cast(&smemoutput[idx]); + uint32_t dst_ptr = *(reinterpret_cast(&smemoutput[idx])); + half (&res_)[2] = reinterpret_cast(dst_ptr); if (n < param.n && row < param.k && col < PQ) { if constexpr (ksplit > 0) { const uint outOffset = z * NKPQ + n * KPQ + row * PQ + col; // output[outOffset] = smemoutput[idx]; - output[outOffset] = reinterpret_cast(dst_ptr)[0]; + // output[outOffset] = reinterpret_cast(dst_ptr)[0]; + output[outOffset] = res_[0]; } else { const uint outOffset = n * KPQ + row * PQ + col; // output[outOffset] = smemoutput[idx]; - output[outOffset] = reinterpret_cast(dst_ptr)[0]; + // output[outOffset] = reinterpret_cast(dst_ptr)[0]; + output[outOffset] = res_[0]; } } if (n < param.n && row+1 < param.k && col < PQ) { @@ -716,11 +722,13 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, n * KPQ + (row+1) * PQ + col; // output[outOffset] = smemoutput[idx]; - output[outOffset] = reinterpret_cast(dst_ptr)[1]; + // output[outOffset] = reinterpret_cast(dst_ptr)[1]; + output[outOffset] = res_[1]; } else { const uint outOffset = n * KPQ + (row+1) * PQ + col; // output[outOffset] = smemoutput[idx]; - output[outOffset] = reinterpret_cast(dst_ptr)[1]; + // output[outOffset] = reinterpret_cast(dst_ptr)[1]; + output[outOffset] = res_[1]; } } } From 949eca4cba3e08afcbfbf48c8d288cbf188b9bb1 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Fri, 7 Nov 2025 19:20:12 -0500 Subject: [PATCH 068/122] swizzling working, may still have room to optimize --- ggml/src/ggml-cuda/conv2d-implicit.cu | 14 ++++---------- tests/test-conv2d.cpp | 20 ++++++++++---------- 2 files changed, 14 insertions(+), 20 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index b02224fc06..3a84935582 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -677,11 +677,14 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, uint32_t (®_)[2] = reinterpret_cast(acc_register_[mma_m][mma_n]); uint idx = output_sts_addr + mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N; + uint idx8 = idx + 8 * BN / 2; idx = idx ^ ((idx & 0b110000000000) >> 9); idx = idx ^ ((idx & 0b1110000000) >> 4); uint32_t* dst_ptr = reinterpret_cast(&smemoutput[idx]); dst_ptr[0] = reg_[0]; - dst_ptr = reinterpret_cast(&smemoutput[idx + 8 * BN / 2]); + idx8 = idx8 ^ ((idx8 & 0b110000000000) >> 9); + idx8 = idx8 ^ ((idx8 & 0b1110000000) >> 4); + dst_ptr = reinterpret_cast(&smemoutput[idx8]); dst_ptr[0] = reg_[1]; } } @@ -698,7 +701,6 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, uint idx = output_lds_addr + subk*2 + j*32*BN/2; idx = idx ^ ((idx & 0b110000000000) >> 9); idx = idx ^ ((idx & 0b1110000000) >> 4); - // uint32_t* dst_ptr = reinterpret_cast(&smemoutput[idx]); uint32_t dst_ptr = *(reinterpret_cast(&smemoutput[idx])); half 
(&res_)[2] = reinterpret_cast(dst_ptr); if (n < param.n && row < param.k && col < PQ) { @@ -706,13 +708,9 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, const uint outOffset = z * NKPQ + n * KPQ + row * PQ + col; - // output[outOffset] = smemoutput[idx]; - // output[outOffset] = reinterpret_cast(dst_ptr)[0]; output[outOffset] = res_[0]; } else { const uint outOffset = n * KPQ + row * PQ + col; - // output[outOffset] = smemoutput[idx]; - // output[outOffset] = reinterpret_cast(dst_ptr)[0]; output[outOffset] = res_[0]; } } @@ -721,13 +719,9 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, const uint outOffset = z * NKPQ + n * KPQ + (row+1) * PQ + col; - // output[outOffset] = smemoutput[idx]; - // output[outOffset] = reinterpret_cast(dst_ptr)[1]; output[outOffset] = res_[1]; } else { const uint outOffset = n * KPQ + (row+1) * PQ + col; - // output[outOffset] = smemoutput[idx]; - // output[outOffset] = reinterpret_cast(dst_ptr)[1]; output[outOffset] = res_[1]; } } diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index 0b1b5c476f..75778b6e30 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -705,16 +705,16 @@ int main(void) // for(int i = 0; i < ggml_nelements(wino_res); i++) { - for(int i = 0; i < 26*38; i++) { - // for(int i = 0; i < conv2d_data.size(); i++) { - float diff = fabs(im2col_data[i] - conv2d_data[i]); - // if(diff > 0.5) { - printf("(%7.3f, %7.3f, %.2f, %d) \n", - im2col_data[i], conv2d_data[i], - diff, i); - // break; - // } - } + // for(int i = 0; i < 26*38; i++) { + // // for(int i = 0; i < conv2d_data.size(); i++) { + // float diff = fabs(im2col_data[i] - conv2d_data[i]); + // // if(diff > 0.5) { + // printf("(%7.3f, %7.3f, %.2f, %d) \n", + // im2col_data[i], conv2d_data[i], + // diff, i); + // // break; + // // } + // } ggml_free(model.ctx); ggml_backend_buffer_free(model.buffer); From 8809af79a8a14c9b98d86ba2c0af2e7d541f8405 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Fri, 7 Nov 2025 22:11:21 -0500 Subject: [PATCH 069/122] now bank conflicts free and performance get a bit boosted too --- ggml/src/ggml-cuda/conv2d-implicit.cu | 6 +- tests/test-conv2d.cpp | 676 +++++++++++++------------- 2 files changed, 341 insertions(+), 341 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 3a84935582..2fd244389d 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -677,14 +677,12 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, uint32_t (®_)[2] = reinterpret_cast(acc_register_[mma_m][mma_n]); uint idx = output_sts_addr + mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N; - uint idx8 = idx + 8 * BN / 2; idx = idx ^ ((idx & 0b110000000000) >> 9); idx = idx ^ ((idx & 0b1110000000) >> 4); uint32_t* dst_ptr = reinterpret_cast(&smemoutput[idx]); dst_ptr[0] = reg_[0]; - idx8 = idx8 ^ ((idx8 & 0b110000000000) >> 9); - idx8 = idx8 ^ ((idx8 & 0b1110000000) >> 4); - dst_ptr = reinterpret_cast(&smemoutput[idx8]); + idx = (idx + 8 * BN / 2 ) ^ 0b010; + dst_ptr = reinterpret_cast(&smemoutput[idx]); dst_ptr[0] = reg_[1]; } } diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index 75778b6e30..720ddbf269 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -300,353 +300,355 @@ int main(void) double time_iter0 = 0.0, time_iter1 = 0.0; std::vector> configs = { - // std::make_tuple(64,64,48,64,3,3), - // 
std::make_tuple(320,320,104,152,3,3), - // std::make_tuple(640,640,52,76,3,3), - // std::make_tuple(640,640,104,152,3,3), - // std::make_tuple(960,320,104,152,3,3), + std::make_tuple(64,64,48,64,3,3), + std::make_tuple(320,320,104,152,3,3), + std::make_tuple(640,640,52,76,3,3), + std::make_tuple(640,640,104,152,3,3), + std::make_tuple(960,320,104,152,3,3), std::make_tuple(1280,1280,26,38,3,3), - // std::make_tuple(4,320,96,128,3,3), - // std::make_tuple(320,4,96,128,3,3), - // std::make_tuple(4,320,64,96,3,3), - // std::make_tuple(320,4,64,96,3,3), - // std::make_tuple(640,640,96,128,3,3), - // std::make_tuple(1280,1280,26,38,1,1), - // std::make_tuple(256,128,768,1024,3,3), - // std::make_tuple(128,3,768,1024,3,3), - // std::make_tuple(256,128,768,1024,1,1), - // std::make_tuple(512,256,384,512,1,1), - // std::make_tuple(1280,640,52,76,3,3), - // std::make_tuple(1920,1280,26,38,3,3), - // std::make_tuple(2560,1280,26,38,3,3), - // std::make_tuple(320,1280,26,38,3,3), - // std::make_tuple(512,512,104,152,3,3), - // std::make_tuple(512,512,208,304,3,3), - // std::make_tuple(512,256,416,608,3,3), - // std::make_tuple(256,128,832,1216,3,3), - // std::make_tuple(256,256,832,1216,3,3), + std::make_tuple(4,320,96,128,3,3), + std::make_tuple(320,4,96,128,3,3), + std::make_tuple(4,320,64,96,3,3), + std::make_tuple(320,4,64,96,3,3), + std::make_tuple(640,640,96,128,3,3), + std::make_tuple(1280,1280,26,38,1,1), + std::make_tuple(256,128,768,1024,3,3), + std::make_tuple(128,3,768,1024,3,3), + std::make_tuple(256,128,768,1024,1,1), + std::make_tuple(512,256,384,512,1,1), + std::make_tuple(1280,640,52,76,3,3), + std::make_tuple(1920,1280,26,38,3,3), + std::make_tuple(2560,1280,26,38,3,3), + std::make_tuple(320,1280,26,38,3,3), + std::make_tuple(512,512,104,152,3,3), + std::make_tuple(512,512,208,304,3,3), + std::make_tuple(512,256,416,608,3,3), + std::make_tuple(256,128,832,1216,3,3), + std::make_tuple(256,256,832,1216,3,3), // std::make_tuple(320,256,1024,1920) - // std::make_tuple(32,64,58,58,3,3) - + std::make_tuple(32,64,58,58,3,3) + }; + std::vector> configs_sdxl_512 = { //512x512 - // std::make_tuple(4,320,64,64,3,3), - // std::make_tuple(320,320,64,64,3,3), - // std::make_tuple(320,320,64,64,3,3), - // std::make_tuple(320,320,64,64,3,3), - // std::make_tuple(320,320,64,64,3,3), - // std::make_tuple(320,320,64,64,3,3), - // std::make_tuple(320,640,32,32,3,3), - // std::make_tuple(640,640,32,32,3,3), - // std::make_tuple(320,640,32,32,3,3), - // std::make_tuple(640,640,32,32,3,3), - // std::make_tuple(640,640,32,32,3,3), - // std::make_tuple(640,640,32,32,3,3), - // std::make_tuple(640,1280,16,16,3,3), - // std::make_tuple(1280,1280,16,16,3,3), - // std::make_tuple(640,1280,16,16,3,3), - // std::make_tuple(1280,1280,16,16,3,3), - // std::make_tuple(1280,1280,16,16,3,3), - // std::make_tuple(1280,1280,16,16,3,3), - // std::make_tuple(1280,1280,16,16,3,3), - // std::make_tuple(1280,1280,16,16,3,3), - // std::make_tuple(1280,1280,16,16,3,3), - // std::make_tuple(2560,1280,16,16,3,3), - // std::make_tuple(1280,1280,16,16,3,3), - // std::make_tuple(2560,1280,16,16,3,3), - // std::make_tuple(2560,1280,16,16,3,3), - // std::make_tuple(1280,1280,16,16,3,3), - // std::make_tuple(2560,1280,16,16,3,3), - // std::make_tuple(1920,1280,16,16,3,3), - // std::make_tuple(1280,1280,16,16,3,3), - // std::make_tuple(1920,1280,16,16,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(1920,640,32,32,3,3), - // std::make_tuple(640,640,32,32,3,3), - // std::make_tuple(1920,640,32,32,3,3), - // 
std::make_tuple(1280,640,32,32,3,3), - // std::make_tuple(640,640,32,32,3,3), - // std::make_tuple(1280,640,32,32,3,3), - // std::make_tuple(960,640,32,32,3,3), - // std::make_tuple(640,640,32,32,3,3), - // std::make_tuple(960,640,32,32,3,3), - // std::make_tuple(640,640,64,64,3,3), - // std::make_tuple(960,320,64,64,3,3), - // std::make_tuple(320,320,64,64,3,3), - // std::make_tuple(960,320,64,64,3,3), - // std::make_tuple(640,320,64,64,3,3), - // std::make_tuple(320,320,64,64,3,3), - // std::make_tuple(640,320,64,64,3,3), - // std::make_tuple(640,320,64,64,3,3), - // std::make_tuple(320,320,64,64,3,3), - // std::make_tuple(640,320,64,64,3,3), - // std::make_tuple(320,4,64,64,3,3), - // std::make_tuple(4,320,64,64,3,3), - // std::make_tuple(320,320,64,64,3,3), - // std::make_tuple(320,320,64,64,3,3), - // std::make_tuple(320,320,64,64,3,3), - // std::make_tuple(320,320,64,64,3,3), - // std::make_tuple(320,320,64,64,3,3), - // std::make_tuple(320,640,32,32,3,3), - // std::make_tuple(640,640,32,32,3,3), - // std::make_tuple(320,640,32,32,3,3), - // std::make_tuple(640,640,32,32,3,3), - // std::make_tuple(640,640,32,32,3,3), - // std::make_tuple(640,640,32,32,3,3), - // std::make_tuple(640,1280,16,16,3,3), - // std::make_tuple(1280,1280,16,16,3,3), - // std::make_tuple(640,1280,16,16,3,3), - // std::make_tuple(1280,1280,16,16,3,3), - // std::make_tuple(1280,1280,16,16,3,3), - // std::make_tuple(1280,1280,16,16,3,3), - // std::make_tuple(1280,1280,16,16,3,3), - // std::make_tuple(1280,1280,16,16,3,3), - // std::make_tuple(1280,1280,16,16,3,3), - // std::make_tuple(2560,1280,16,16,3,3), - // std::make_tuple(1280,1280,16,16,3,3), - // std::make_tuple(2560,1280,16,16,3,3), - // std::make_tuple(2560,1280,16,16,3,3), - // std::make_tuple(1280,1280,16,16,3,3), - // std::make_tuple(2560,1280,16,16,3,3), - // std::make_tuple(1920,1280,16,16,3,3), - // std::make_tuple(1280,1280,16,16,3,3), - // std::make_tuple(1920,1280,16,16,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(1920,640,32,32,3,3), - // std::make_tuple(640,640,32,32,3,3), - // std::make_tuple(1920,640,32,32,3,3), - // std::make_tuple(1280,640,32,32,3,3), - // std::make_tuple(640,640,32,32,3,3), - // std::make_tuple(1280,640,32,32,3,3), - // std::make_tuple(960,640,32,32,3,3), - // std::make_tuple(640,640,32,32,3,3), - // std::make_tuple(960,640,32,32,3,3), - // std::make_tuple(640,640,64,64,3,3), - // std::make_tuple(960,320,64,64,3,3), - // std::make_tuple(320,320,64,64,3,3), - // std::make_tuple(960,320,64,64,3,3), - // std::make_tuple(640,320,64,64,3,3), - // std::make_tuple(320,320,64,64,3,3), - // std::make_tuple(640,320,64,64,3,3), - // std::make_tuple(640,320,64,64,3,3), - // std::make_tuple(320,320,64,64,3,3), - // std::make_tuple(640,320,64,64,3,3), - // std::make_tuple(320,4,64,64,3,3), + std::make_tuple(4,320,64,64,3,3), + std::make_tuple(320,320,64,64,3,3), + std::make_tuple(320,320,64,64,3,3), + std::make_tuple(320,320,64,64,3,3), + std::make_tuple(320,320,64,64,3,3), + std::make_tuple(320,320,64,64,3,3), + std::make_tuple(320,640,32,32,3,3), + std::make_tuple(640,640,32,32,3,3), + std::make_tuple(320,640,32,32,3,3), + std::make_tuple(640,640,32,32,3,3), + std::make_tuple(640,640,32,32,3,3), + std::make_tuple(640,640,32,32,3,3), + std::make_tuple(640,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(640,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + 
std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(2560,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(2560,1280,16,16,3,3), + std::make_tuple(2560,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(2560,1280,16,16,3,3), + std::make_tuple(1920,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(1920,1280,16,16,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(1920,640,32,32,3,3), + std::make_tuple(640,640,32,32,3,3), + std::make_tuple(1920,640,32,32,3,3), + std::make_tuple(1280,640,32,32,3,3), + std::make_tuple(640,640,32,32,3,3), + std::make_tuple(1280,640,32,32,3,3), + std::make_tuple(960,640,32,32,3,3), + std::make_tuple(640,640,32,32,3,3), + std::make_tuple(960,640,32,32,3,3), + std::make_tuple(640,640,64,64,3,3), + std::make_tuple(960,320,64,64,3,3), + std::make_tuple(320,320,64,64,3,3), + std::make_tuple(960,320,64,64,3,3), + std::make_tuple(640,320,64,64,3,3), + std::make_tuple(320,320,64,64,3,3), + std::make_tuple(640,320,64,64,3,3), + std::make_tuple(640,320,64,64,3,3), + std::make_tuple(320,320,64,64,3,3), + std::make_tuple(640,320,64,64,3,3), + std::make_tuple(320,4,64,64,3,3), + std::make_tuple(4,320,64,64,3,3), + std::make_tuple(320,320,64,64,3,3), + std::make_tuple(320,320,64,64,3,3), + std::make_tuple(320,320,64,64,3,3), + std::make_tuple(320,320,64,64,3,3), + std::make_tuple(320,320,64,64,3,3), + std::make_tuple(320,640,32,32,3,3), + std::make_tuple(640,640,32,32,3,3), + std::make_tuple(320,640,32,32,3,3), + std::make_tuple(640,640,32,32,3,3), + std::make_tuple(640,640,32,32,3,3), + std::make_tuple(640,640,32,32,3,3), + std::make_tuple(640,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(640,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(2560,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(2560,1280,16,16,3,3), + std::make_tuple(2560,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(2560,1280,16,16,3,3), + std::make_tuple(1920,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(1920,1280,16,16,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(1920,640,32,32,3,3), + std::make_tuple(640,640,32,32,3,3), + std::make_tuple(1920,640,32,32,3,3), + std::make_tuple(1280,640,32,32,3,3), + std::make_tuple(640,640,32,32,3,3), + std::make_tuple(1280,640,32,32,3,3), + std::make_tuple(960,640,32,32,3,3), + std::make_tuple(640,640,32,32,3,3), + std::make_tuple(960,640,32,32,3,3), + std::make_tuple(640,640,64,64,3,3), + std::make_tuple(960,320,64,64,3,3), + std::make_tuple(320,320,64,64,3,3), + std::make_tuple(960,320,64,64,3,3), + std::make_tuple(640,320,64,64,3,3), + std::make_tuple(320,320,64,64,3,3), + std::make_tuple(640,320,64,64,3,3), + std::make_tuple(640,320,64,64,3,3), + std::make_tuple(320,320,64,64,3,3), + std::make_tuple(640,320,64,64,3,3), + std::make_tuple(320,4,64,64,3,3) + }; + std::vector> configs_sdxl_768 = { //768x768 - // std::make_tuple(4,320,96,96,3,3), - // std::make_tuple(320,320,96,96,3,3), - // std::make_tuple(320,320,96,96,3,3), - // std::make_tuple(320,320,96,96,3,3), - // std::make_tuple(320,320,96,96,3,3), - // std::make_tuple(320,320,96,96,3,3), - // std::make_tuple(320,640,48,48,3,3), - // 
std::make_tuple(640,640,48,48,3,3), - // std::make_tuple(320,640,48,48,3,3), - // std::make_tuple(640,640,48,48,3,3), - // std::make_tuple(640,640,48,48,3,3), - // std::make_tuple(640,640,48,48,3,3), - // std::make_tuple(640,1280,24,24,3,3), - // std::make_tuple(1280,1280,24,24,3,3), - // std::make_tuple(640,1280,24,24,3,3), - // std::make_tuple(1280,1280,24,24,3,3), - // std::make_tuple(1280,1280,24,24,3,3), - // std::make_tuple(1280,1280,24,24,3,3), - // std::make_tuple(1280,1280,24,24,3,3), - // std::make_tuple(1280,1280,24,24,3,3), - // std::make_tuple(1280,1280,24,24,3,3), - // std::make_tuple(2560,1280,24,24,3,3), - // std::make_tuple(1280,1280,24,24,3,3), - // std::make_tuple(2560,1280,24,24,3,3), - // std::make_tuple(2560,1280,24,24,3,3), - // std::make_tuple(1280,1280,24,24,3,3), - // std::make_tuple(2560,1280,24,24,3,3), - // std::make_tuple(1920,1280,24,24,3,3), - // std::make_tuple(1280,1280,24,24,3,3), - // std::make_tuple(1920,1280,24,24,3,3), - // std::make_tuple(1280,1280,48,48,3,3), - // std::make_tuple(1920,640,48,48,3,3), - // std::make_tuple(640,640,48,48,3,3), - // std::make_tuple(1920,640,48,48,3,3), - // std::make_tuple(1280,640,48,48,3,3), - // std::make_tuple(640,640,48,48,3,3), - // std::make_tuple(1280,640,48,48,3,3), - // std::make_tuple(960,640,48,48,3,3), - // std::make_tuple(640,640,48,48,3,3), - // std::make_tuple(960,640,48,48,3,3), - // std::make_tuple(640,640,96,96,3,3), - // std::make_tuple(960,320,96,96,3,3), - // std::make_tuple(320,320,96,96,3,3), - // std::make_tuple(960,320,96,96,3,3), - // std::make_tuple(640,320,96,96,3,3), - // std::make_tuple(320,320,96,96,3,3), - // std::make_tuple(640,320,96,96,3,3), - // std::make_tuple(640,320,96,96,3,3), - // std::make_tuple(320,320,96,96,3,3), - // std::make_tuple(640,320,96,96,3,3), - // std::make_tuple(320,4,96,96,3,3), - // std::make_tuple(4,320,96,96,3,3), - // std::make_tuple(320,320,96,96,3,3), - // std::make_tuple(320,320,96,96,3,3), - // std::make_tuple(320,320,96,96,3,3), - // std::make_tuple(320,320,96,96,3,3), - // std::make_tuple(320,320,96,96,3,3), - // std::make_tuple(320,640,48,48,3,3), - // std::make_tuple(640,640,48,48,3,3), - // std::make_tuple(320,640,48,48,3,3), - // std::make_tuple(640,640,48,48,3,3), - // std::make_tuple(640,640,48,48,3,3), - // std::make_tuple(640,640,48,48,3,3), - // std::make_tuple(640,1280,24,24,3,3), - // std::make_tuple(1280,1280,24,24,3,3), - // std::make_tuple(640,1280,24,24,3,3), - // std::make_tuple(1280,1280,24,24,3,3), - // std::make_tuple(1280,1280,24,24,3,3), - // std::make_tuple(1280,1280,24,24,3,3), - // std::make_tuple(1280,1280,24,24,3,3), - // std::make_tuple(1280,1280,24,24,3,3), - // std::make_tuple(1280,1280,24,24,3,3), - // std::make_tuple(2560,1280,24,24,3,3), - // std::make_tuple(1280,1280,24,24,3,3), - // std::make_tuple(2560,1280,24,24,3,3), - // std::make_tuple(2560,1280,24,24,3,3), - // std::make_tuple(1280,1280,24,24,3,3), - // std::make_tuple(2560,1280,24,24,3,3), - // std::make_tuple(1920,1280,24,24,3,3), - // std::make_tuple(1280,1280,24,24,3,3), - // std::make_tuple(1920,1280,24,24,3,3), - // std::make_tuple(1280,1280,48,48,3,3), - // std::make_tuple(1920,640,48,48,3,3), - // std::make_tuple(640,640,48,48,3,3), - // std::make_tuple(1920,640,48,48,3,3), - // std::make_tuple(1280,640,48,48,3,3), - // std::make_tuple(640,640,48,48,3,3), - // std::make_tuple(1280,640,48,48,3,3), - // std::make_tuple(960,640,48,48,3,3), - // std::make_tuple(640,640,48,48,3,3), - // std::make_tuple(960,640,48,48,3,3), - // 
std::make_tuple(640,640,96,96,3,3), - // std::make_tuple(960,320,96,96,3,3), - // std::make_tuple(320,320,96,96,3,3), - // std::make_tuple(960,320,96,96,3,3), - // std::make_tuple(640,320,96,96,3,3), - // std::make_tuple(320,320,96,96,3,3), - // std::make_tuple(640,320,96,96,3,3), - // std::make_tuple(640,320,96,96,3,3), - // std::make_tuple(320,320,96,96,3,3), - // std::make_tuple(640,320,96,96,3,3), - // std::make_tuple(320,4,96,96,3,3), - + std::make_tuple(4,320,96,96,3,3), + std::make_tuple(320,320,96,96,3,3), + std::make_tuple(320,320,96,96,3,3), + std::make_tuple(320,320,96,96,3,3), + std::make_tuple(320,320,96,96,3,3), + std::make_tuple(320,320,96,96,3,3), + std::make_tuple(320,640,48,48,3,3), + std::make_tuple(640,640,48,48,3,3), + std::make_tuple(320,640,48,48,3,3), + std::make_tuple(640,640,48,48,3,3), + std::make_tuple(640,640,48,48,3,3), + std::make_tuple(640,640,48,48,3,3), + std::make_tuple(640,1280,24,24,3,3), + std::make_tuple(1280,1280,24,24,3,3), + std::make_tuple(640,1280,24,24,3,3), + std::make_tuple(1280,1280,24,24,3,3), + std::make_tuple(1280,1280,24,24,3,3), + std::make_tuple(1280,1280,24,24,3,3), + std::make_tuple(1280,1280,24,24,3,3), + std::make_tuple(1280,1280,24,24,3,3), + std::make_tuple(1280,1280,24,24,3,3), + std::make_tuple(2560,1280,24,24,3,3), + std::make_tuple(1280,1280,24,24,3,3), + std::make_tuple(2560,1280,24,24,3,3), + std::make_tuple(2560,1280,24,24,3,3), + std::make_tuple(1280,1280,24,24,3,3), + std::make_tuple(2560,1280,24,24,3,3), + std::make_tuple(1920,1280,24,24,3,3), + std::make_tuple(1280,1280,24,24,3,3), + std::make_tuple(1920,1280,24,24,3,3), + std::make_tuple(1280,1280,48,48,3,3), + std::make_tuple(1920,640,48,48,3,3), + std::make_tuple(640,640,48,48,3,3), + std::make_tuple(1920,640,48,48,3,3), + std::make_tuple(1280,640,48,48,3,3), + std::make_tuple(640,640,48,48,3,3), + std::make_tuple(1280,640,48,48,3,3), + std::make_tuple(960,640,48,48,3,3), + std::make_tuple(640,640,48,48,3,3), + std::make_tuple(960,640,48,48,3,3), + std::make_tuple(640,640,96,96,3,3), + std::make_tuple(960,320,96,96,3,3), + std::make_tuple(320,320,96,96,3,3), + std::make_tuple(960,320,96,96,3,3), + std::make_tuple(640,320,96,96,3,3), + std::make_tuple(320,320,96,96,3,3), + std::make_tuple(640,320,96,96,3,3), + std::make_tuple(640,320,96,96,3,3), + std::make_tuple(320,320,96,96,3,3), + std::make_tuple(640,320,96,96,3,3), + std::make_tuple(320,4,96,96,3,3), + std::make_tuple(4,320,96,96,3,3), + std::make_tuple(320,320,96,96,3,3), + std::make_tuple(320,320,96,96,3,3), + std::make_tuple(320,320,96,96,3,3), + std::make_tuple(320,320,96,96,3,3), + std::make_tuple(320,320,96,96,3,3), + std::make_tuple(320,640,48,48,3,3), + std::make_tuple(640,640,48,48,3,3), + std::make_tuple(320,640,48,48,3,3), + std::make_tuple(640,640,48,48,3,3), + std::make_tuple(640,640,48,48,3,3), + std::make_tuple(640,640,48,48,3,3), + std::make_tuple(640,1280,24,24,3,3), + std::make_tuple(1280,1280,24,24,3,3), + std::make_tuple(640,1280,24,24,3,3), + std::make_tuple(1280,1280,24,24,3,3), + std::make_tuple(1280,1280,24,24,3,3), + std::make_tuple(1280,1280,24,24,3,3), + std::make_tuple(1280,1280,24,24,3,3), + std::make_tuple(1280,1280,24,24,3,3), + std::make_tuple(1280,1280,24,24,3,3), + std::make_tuple(2560,1280,24,24,3,3), + std::make_tuple(1280,1280,24,24,3,3), + std::make_tuple(2560,1280,24,24,3,3), + std::make_tuple(2560,1280,24,24,3,3), + std::make_tuple(1280,1280,24,24,3,3), + std::make_tuple(2560,1280,24,24,3,3), + std::make_tuple(1920,1280,24,24,3,3), + std::make_tuple(1280,1280,24,24,3,3), + 
std::make_tuple(1920,1280,24,24,3,3), + std::make_tuple(1280,1280,48,48,3,3), + std::make_tuple(1920,640,48,48,3,3), + std::make_tuple(640,640,48,48,3,3), + std::make_tuple(1920,640,48,48,3,3), + std::make_tuple(1280,640,48,48,3,3), + std::make_tuple(640,640,48,48,3,3), + std::make_tuple(1280,640,48,48,3,3), + std::make_tuple(960,640,48,48,3,3), + std::make_tuple(640,640,48,48,3,3), + std::make_tuple(960,640,48,48,3,3), + std::make_tuple(640,640,96,96,3,3), + std::make_tuple(960,320,96,96,3,3), + std::make_tuple(320,320,96,96,3,3), + std::make_tuple(960,320,96,96,3,3), + std::make_tuple(640,320,96,96,3,3), + std::make_tuple(320,320,96,96,3,3), + std::make_tuple(640,320,96,96,3,3), + std::make_tuple(640,320,96,96,3,3), + std::make_tuple(320,320,96,96,3,3), + std::make_tuple(640,320,96,96,3,3), + std::make_tuple(320,4,96,96,3,3), + }; + std::vector> configs_sdxl_1024 = { //1024x1024 - // std::make_tuple(4,320,128,128,3,3), - // std::make_tuple(320,320,128,128,3,3), - // std::make_tuple(320,320,128,128,3,3), - // std::make_tuple(320,320,128,128,3,3), - // std::make_tuple(320,320,128,128,3,3), - // std::make_tuple(320,320,128,128,3,3), - // std::make_tuple(320,640,64,64,3,3), - // std::make_tuple(640,640,64,64,3,3), - // std::make_tuple(320,640,64,64,3,3), - // std::make_tuple(640,640,64,64,3,3), - // std::make_tuple(640,640,64,64,3,3), - // std::make_tuple(640,640,64,64,3,3), - // std::make_tuple(640,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(640,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(2560,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(2560,1280,32,32,3,3), - // std::make_tuple(2560,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(2560,1280,32,32,3,3), - // std::make_tuple(1920,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(1920,1280,32,32,3,3), - // std::make_tuple(1280,1280,64,64,3,3), - // std::make_tuple(1920,640,64,64,3,3), - // std::make_tuple(640,640,64,64,3,3), - // std::make_tuple(1920,640,64,64,3,3), - // std::make_tuple(1280,640,64,64,3,3), - // std::make_tuple(640,640,64,64,3,3), - // std::make_tuple(1280,640,64,64,3,3), - // std::make_tuple(960,640,64,64,3,3), - // std::make_tuple(640,640,64,64,3,3), - // std::make_tuple(960,640,64,64,3,3), - // std::make_tuple(640,640,128,128,3,3), - // std::make_tuple(960,320,128,128,3,3), - // std::make_tuple(320,320,128,128,3,3), - // std::make_tuple(960,320,128,128,3,3), - // std::make_tuple(640,320,128,128,3,3), - // std::make_tuple(320,320,128,128,3,3), - // std::make_tuple(640,320,128,128,3,3), - // std::make_tuple(640,320,128,128,3,3), - // std::make_tuple(320,320,128,128,3,3), - // std::make_tuple(640,320,128,128,3,3), - // std::make_tuple(320,4,128,128,3,3), - // std::make_tuple(4,320,128,128,3,3), - // std::make_tuple(320,320,128,128,3,3), - // std::make_tuple(320,320,128,128,3,3), - // std::make_tuple(320,320,128,128,3,3), - // std::make_tuple(320,320,128,128,3,3), - // std::make_tuple(320,320,128,128,3,3), - // std::make_tuple(320,640,64,64,3,3), - // std::make_tuple(640,640,64,64,3,3), - // std::make_tuple(320,640,64,64,3,3), - // std::make_tuple(640,640,64,64,3,3), - // std::make_tuple(640,640,64,64,3,3), - // 
std::make_tuple(640,640,64,64,3,3), - // std::make_tuple(640,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(640,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(2560,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(2560,1280,32,32,3,3), - // std::make_tuple(2560,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(2560,1280,32,32,3,3), - // std::make_tuple(1920,1280,32,32,3,3), - // std::make_tuple(1280,1280,32,32,3,3), - // std::make_tuple(1920,1280,32,32,3,3), - // std::make_tuple(1280,1280,64,64,3,3), - // std::make_tuple(1920,640,64,64,3,3), - // std::make_tuple(640,640,64,64,3,3), - // std::make_tuple(1920,640,64,64,3,3), - // std::make_tuple(1280,640,64,64,3,3), - // std::make_tuple(640,640,64,64,3,3), - // std::make_tuple(1280,640,64,64,3,3), - // std::make_tuple(960,640,64,64,3,3), - // std::make_tuple(640,640,64,64,3,3), - // std::make_tuple(960,640,64,64,3,3), - // std::make_tuple(640,640,128,128,3,3), - // std::make_tuple(960,320,128,128,3,3), - // std::make_tuple(320,320,128,128,3,3), - // std::make_tuple(960,320,128,128,3,3), - // std::make_tuple(640,320,128,128,3,3), - // std::make_tuple(320,320,128,128,3,3), - // std::make_tuple(640,320,128,128,3,3), - // std::make_tuple(640,320,128,128,3,3), - // std::make_tuple(320,320,128,128,3,3), - // std::make_tuple(640,320,128,128,3,3), - // std::make_tuple(320,4,128,128,3,3), - - + std::make_tuple(4,320,128,128,3,3), + std::make_tuple(320,320,128,128,3,3), + std::make_tuple(320,320,128,128,3,3), + std::make_tuple(320,320,128,128,3,3), + std::make_tuple(320,320,128,128,3,3), + std::make_tuple(320,320,128,128,3,3), + std::make_tuple(320,640,64,64,3,3), + std::make_tuple(640,640,64,64,3,3), + std::make_tuple(320,640,64,64,3,3), + std::make_tuple(640,640,64,64,3,3), + std::make_tuple(640,640,64,64,3,3), + std::make_tuple(640,640,64,64,3,3), + std::make_tuple(640,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(640,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(2560,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(2560,1280,32,32,3,3), + std::make_tuple(2560,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(2560,1280,32,32,3,3), + std::make_tuple(1920,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(1920,1280,32,32,3,3), + std::make_tuple(1280,1280,64,64,3,3), + std::make_tuple(1920,640,64,64,3,3), + std::make_tuple(640,640,64,64,3,3), + std::make_tuple(1920,640,64,64,3,3), + std::make_tuple(1280,640,64,64,3,3), + std::make_tuple(640,640,64,64,3,3), + std::make_tuple(1280,640,64,64,3,3), + std::make_tuple(960,640,64,64,3,3), + std::make_tuple(640,640,64,64,3,3), + std::make_tuple(960,640,64,64,3,3), + std::make_tuple(640,640,128,128,3,3), + std::make_tuple(960,320,128,128,3,3), + std::make_tuple(320,320,128,128,3,3), + std::make_tuple(960,320,128,128,3,3), + std::make_tuple(640,320,128,128,3,3), + std::make_tuple(320,320,128,128,3,3), + std::make_tuple(640,320,128,128,3,3), + 
std::make_tuple(640,320,128,128,3,3), + std::make_tuple(320,320,128,128,3,3), + std::make_tuple(640,320,128,128,3,3), + std::make_tuple(320,4,128,128,3,3), + std::make_tuple(4,320,128,128,3,3), + std::make_tuple(320,320,128,128,3,3), + std::make_tuple(320,320,128,128,3,3), + std::make_tuple(320,320,128,128,3,3), + std::make_tuple(320,320,128,128,3,3), + std::make_tuple(320,320,128,128,3,3), + std::make_tuple(320,640,64,64,3,3), + std::make_tuple(640,640,64,64,3,3), + std::make_tuple(320,640,64,64,3,3), + std::make_tuple(640,640,64,64,3,3), + std::make_tuple(640,640,64,64,3,3), + std::make_tuple(640,640,64,64,3,3), + std::make_tuple(640,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(640,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(2560,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(2560,1280,32,32,3,3), + std::make_tuple(2560,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(2560,1280,32,32,3,3), + std::make_tuple(1920,1280,32,32,3,3), + std::make_tuple(1280,1280,32,32,3,3), + std::make_tuple(1920,1280,32,32,3,3), + std::make_tuple(1280,1280,64,64,3,3), + std::make_tuple(1920,640,64,64,3,3), + std::make_tuple(640,640,64,64,3,3), + std::make_tuple(1920,640,64,64,3,3), + std::make_tuple(1280,640,64,64,3,3), + std::make_tuple(640,640,64,64,3,3), + std::make_tuple(1280,640,64,64,3,3), + std::make_tuple(960,640,64,64,3,3), + std::make_tuple(640,640,64,64,3,3), + std::make_tuple(960,640,64,64,3,3), + std::make_tuple(640,640,128,128,3,3), + std::make_tuple(960,320,128,128,3,3), + std::make_tuple(320,320,128,128,3,3), + std::make_tuple(960,320,128,128,3,3), + std::make_tuple(640,320,128,128,3,3), + std::make_tuple(320,320,128,128,3,3), + std::make_tuple(640,320,128,128,3,3), + std::make_tuple(640,320,128,128,3,3), + std::make_tuple(320,320,128,128,3,3), + std::make_tuple(640,320,128,128,3,3), + std::make_tuple(320,4,128,128,3,3) }; int k = 0; - for (auto c : configs){ + for (auto c : configs_sdxl_1024){ test_model model; load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c), std::get<3>(c), std::get<4>(c), std::get<5>(c), true); @@ -663,7 +665,7 @@ int main(void) // fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f); - int iterations = 0; + int iterations = 20; double run_time0; std::vector im2col_data = compute_graph(model, allocr, build_graph_0, iterations, &run_time0); From 414bb8d9ed02050c6f37b4cfe57ce70a78434bfe Mon Sep 17 00:00:00 2001 From: bssrdf Date: Fri, 7 Nov 2025 23:20:46 -0500 Subject: [PATCH 070/122] further reduce index swizzling computation cycles --- ggml/src/ggml-cuda/conv2d-implicit.cu | 11 +++++++---- tests/test-conv2d.cpp | 5 +++-- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 2fd244389d..0ec9dca1bd 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -691,15 +691,18 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, #pragma unroll for (int subk = 0; subk < WN / 4; ++subk){ const uint row = m_i_wn + subk*2; + uint idx = output_lds_addr + subk*2; + idx = idx ^ ((idx & 0b110000000000) >> 9); + idx = idx ^ ((idx & 0b1110000000) >> 4); #pragma unroll 
for (int j = 0; j < 4; ++j){ const uint gemm_i = n_idx + j*32; const int n = fastdiv(gemm_i, param.OHOW_fastdiv); const int col = fastmodulo(gemm_i, param.OHOW_fastdiv); - uint idx = output_lds_addr + subk*2 + j*32*BN/2; - idx = idx ^ ((idx & 0b110000000000) >> 9); - idx = idx ^ ((idx & 0b1110000000) >> 4); - uint32_t dst_ptr = *(reinterpret_cast(&smemoutput[idx])); + // uint idx = output_lds_addr + subk*2 + j*32*BN/2; + // idx = idx ^ ((idx & 0b110000000000) >> 9); + // idx = idx ^ ((idx & 0b1110000000) >> 4); + uint32_t dst_ptr = *(reinterpret_cast(&smemoutput[idx+j*32*BN/2])); half (&res_)[2] = reinterpret_cast(dst_ptr); if (n < param.n && row < param.k && col < PQ) { if constexpr (ksplit > 0) { diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index 720ddbf269..57edc02474 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -325,8 +325,8 @@ int main(void) std::make_tuple(512,256,416,608,3,3), std::make_tuple(256,128,832,1216,3,3), std::make_tuple(256,256,832,1216,3,3), - // std::make_tuple(320,256,1024,1920) std::make_tuple(32,64,58,58,3,3) + // std::make_tuple(320,256,1024,1920) }; std::vector> configs_sdxl_512 = { //512x512 @@ -648,7 +648,8 @@ int main(void) int k = 0; - for (auto c : configs_sdxl_1024){ + for (auto c : configs_sdxl_512){ + // for (auto c : configs){ test_model model; load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c), std::get<3>(c), std::get<4>(c), std::get<5>(c), true); From 64ead3fd4fa7662360482cd07b3e16308cb3433f Mon Sep 17 00:00:00 2001 From: bssrdf Date: Fri, 7 Nov 2025 23:21:30 -0500 Subject: [PATCH 071/122] remove commented code --- ggml/src/ggml-cuda/conv2d-implicit.cu | 3 --- 1 file changed, 3 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 0ec9dca1bd..33f5ac23a7 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -699,9 +699,6 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, const uint gemm_i = n_idx + j*32; const int n = fastdiv(gemm_i, param.OHOW_fastdiv); const int col = fastmodulo(gemm_i, param.OHOW_fastdiv); - // uint idx = output_lds_addr + subk*2 + j*32*BN/2; - // idx = idx ^ ((idx & 0b110000000000) >> 9); - // idx = idx ^ ((idx & 0b1110000000) >> 4); uint32_t dst_ptr = *(reinterpret_cast(&smemoutput[idx+j*32*BN/2])); half (&res_)[2] = reinterpret_cast(dst_ptr); if (n < param.n && row < param.k && col < PQ) { From 9cbc099493b0a9f68d15e2ab597b1824611e99ce Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sat, 8 Nov 2025 14:51:45 -0500 Subject: [PATCH 072/122] broken for some test cases --- ggml/src/ggml-cuda/conv2d-implicit.cu | 117 +++++++++++++++++++------- tests/test-conv2d.cpp | 79 ++++++++--------- 2 files changed, 127 insertions(+), 69 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 33f5ac23a7..1133626d14 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -779,6 +779,33 @@ static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, WNITER, TM, TN, NUM_THREADS, 1, false, 0><<>>(X_D, K_D, Y_D, P); } +template +static void launch_conv2d_implicit_split_kernel(ggml_backend_cuda_context & ctx, const half *X_H, const half *K_H, float *Y_D, + const unsigned int BlocksM, const unsigned int BlocksN, + const unsigned int shmem_bytes, + const param_t P, cudaStream_t st){ + + int id = ggml_cuda_get_device(); + + ggml_cuda_pool_alloc Y_H(ctx.pool(id), ksplit * 
P.k * P.Oh * P.Ow * P.n); + cudaFuncSetAttribute(conv2d_implicit_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75 + dim3 gridDim(BlocksN, BlocksM, ksplit); + dim3 blockDim(ThreadsN, ThreadsM); + + conv2d_implicit_kernel<<>>(X_H, K_H, Y_H.get(), P); + + const unsigned int nrows = P.n * P.k * P.Oh * P.Ow; + const unsigned int blockx = (nrows + 511) / 512; + const dim3 block_nums(blockx, 1, 1); + const dim3 block_dims(512, 1, 1); + reduce_f32<<>>(Y_H.get(), Y_D, nrows, ksplit); +} + static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const float * X_D, const half * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) { if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0 && (P.r > 1 || P.s > 1)) { @@ -829,39 +856,67 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half); const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; - const unsigned int ksplit = 8; - if (BlocksM * BlocksN < nsm && P.c >= 8 * ksplit && (P.c * P.r * P.s) % (8*ksplit) == 0) { - ggml_cuda_pool_alloc Y_H(ctx.pool(id), ksplit * P.k * P.Oh * P.Ow * P.n); + // const unsigned int ksplit = 6; + // if (BlocksM * BlocksN < nsm && P.c >= 8 * ksplit && (P.c * P.r * P.s) % (8*ksplit) == 0) { + printf("split factor info = %d, %d, %d \n", BlocksM, BlocksN, nsm / (BlocksM * BlocksN)); + if (BlocksM * BlocksN < nsm && nsm / (BlocksM * BlocksN) <= 8 ){ - cudaFuncSetAttribute(conv2d_implicit_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75 - dim3 gridDim(BlocksN, BlocksM, ksplit); - dim3 blockDim(ThreadsN, ThreadsM); - - conv2d_implicit_kernel - <<>>(X_H, K_H, Y_H.get(), P); - - const unsigned int nrows = P.n * P.k * P.Oh * P.Ow; - const unsigned int blockx = (nrows + 511) / 512; - const dim3 block_nums(blockx, 1, 1); - const dim3 block_dims(512, 1, 1); - reduce_f32<<>>(Y_H.get(), Y_D, nrows, ksplit); - - } else { - ggml_cuda_pool_alloc Y_H(ctx.pool(id), P.k * P.Oh * P.Ow * P.n); - - cudaFuncSetAttribute(conv2d_implicit_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75 - dim3 gridDim(BlocksN, BlocksM); - dim3 blockDim(ThreadsN, ThreadsM); - - conv2d_implicit_kernel - <<>>(X_H, K_H, Y_H.get(), P); - const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); - to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow * P.n, st); + int ks = nsm / (BlocksM * BlocksN); + printf("split factor init = %d \n", ks); + int j; + bool can_split = false; + for (j = ks; j >= 2; j--){ + if ((P.c * P.r * P.s) % (8*j) == 0){ + can_split = true; + break; + } + } + if(can_split){ + printf("split factor = %d \n", j); + if (j == 2) { + const unsigned int ksplit = 2; + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else if (j == 3) { + const unsigned int ksplit = 3; + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else if (j == 4) { + const unsigned int ksplit = 4; + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else if (j == 5) { + const unsigned int ksplit = 5; + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else if (j == 6) { + const unsigned int 
ksplit = 6; + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else if (j == 7) { + const unsigned int ksplit = 7; + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else if (j == 8) { + const unsigned int ksplit = 8; + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } + return; + } } + ggml_cuda_pool_alloc Y_H(ctx.pool(id), P.k * P.Oh * P.Ow * P.n); + + cudaFuncSetAttribute(conv2d_implicit_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75 + dim3 gridDim(BlocksN, BlocksM); + dim3 blockDim(ThreadsN, ThreadsM); + + conv2d_implicit_kernel + <<>>(X_H, K_H, Y_H.get(), P); + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); + to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow * P.n, st); } else{ conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); } diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index 57edc02474..5af7da0a91 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -293,42 +293,38 @@ std::vector compute_graph(const test_model & model, ggml_gallocr_t allocr } - -int main(void) -{ - ggml_time_init(); - - double time_iter0 = 0.0, time_iter1 = 0.0; - std::vector> configs = { - std::make_tuple(64,64,48,64,3,3), - std::make_tuple(320,320,104,152,3,3), - std::make_tuple(640,640,52,76,3,3), - std::make_tuple(640,640,104,152,3,3), - std::make_tuple(960,320,104,152,3,3), - std::make_tuple(1280,1280,26,38,3,3), - std::make_tuple(4,320,96,128,3,3), - std::make_tuple(320,4,96,128,3,3), - std::make_tuple(4,320,64,96,3,3), - std::make_tuple(320,4,64,96,3,3), - std::make_tuple(640,640,96,128,3,3), - std::make_tuple(1280,1280,26,38,1,1), - std::make_tuple(256,128,768,1024,3,3), - std::make_tuple(128,3,768,1024,3,3), - std::make_tuple(256,128,768,1024,1,1), - std::make_tuple(512,256,384,512,1,1), - std::make_tuple(1280,640,52,76,3,3), - std::make_tuple(1920,1280,26,38,3,3), - std::make_tuple(2560,1280,26,38,3,3), - std::make_tuple(320,1280,26,38,3,3), - std::make_tuple(512,512,104,152,3,3), - std::make_tuple(512,512,208,304,3,3), - std::make_tuple(512,256,416,608,3,3), - std::make_tuple(256,128,832,1216,3,3), - std::make_tuple(256,256,832,1216,3,3), - std::make_tuple(32,64,58,58,3,3) +static std::vector> configs = { + // std::make_tuple(64,64,48,64,3,3), + // std::make_tuple(320,320,104,152,3,3), + // std::make_tuple(640,640,52,76,3,3), + // std::make_tuple(640,640,104,152,3,3), + // std::make_tuple(960,320,104,152,3,3), + // std::make_tuple(1280,1280,26,38,3,3), + std::make_tuple(1920,640,32,32,3,3) + // std::make_tuple(4,320,96,128,3,3), + // std::make_tuple(320,4,96,128,3,3), + // std::make_tuple(4,320,64,96,3,3), + // std::make_tuple(320,4,64,96,3,3), + // std::make_tuple(640,640,96,128,3,3), + // std::make_tuple(1280,1280,26,38,1,1), + // std::make_tuple(256,128,768,1024,3,3), + // std::make_tuple(128,3,768,1024,3,3), + // std::make_tuple(256,128,768,1024,1,1), + // std::make_tuple(512,256,384,512,1,1), + // std::make_tuple(1280,640,52,76,3,3), + // std::make_tuple(1920,1280,26,38,3,3), + // std::make_tuple(2560,1280,26,38,3,3), + // std::make_tuple(320,1280,26,38,3,3), + // std::make_tuple(512,512,104,152,3,3), + // std::make_tuple(512,512,208,304,3,3), + // std::make_tuple(512,256,416,608,3,3), + // std::make_tuple(256,128,832,1216,3,3), + // std::make_tuple(256,256,832,1216,3,3), + // std::make_tuple(32,64,58,58,3,3) // 
std::make_tuple(320,256,1024,1920) }; - std::vector> configs_sdxl_512 = { + +static std::vector> configs_sdxl_512 = { //512x512 std::make_tuple(4,320,64,64,3,3), std::make_tuple(320,320,64,64,3,3), @@ -434,7 +430,7 @@ int main(void) std::make_tuple(320,4,64,64,3,3) }; - std::vector> configs_sdxl_768 = { +static std::vector> configs_sdxl_768 = { //768x768 std::make_tuple(4,320,96,96,3,3), std::make_tuple(320,320,96,96,3,3), @@ -540,7 +536,7 @@ int main(void) std::make_tuple(320,4,96,96,3,3), }; - std::vector> configs_sdxl_1024 = { +static std::vector> configs_sdxl_1024 = { //1024x1024 std::make_tuple(4,320,128,128,3,3), std::make_tuple(320,320,128,128,3,3), @@ -646,10 +642,17 @@ int main(void) std::make_tuple(320,4,128,128,3,3) }; + +int main(void) +{ + ggml_time_init(); + + double time_iter0 = 0.0, time_iter1 = 0.0; + int k = 0; - for (auto c : configs_sdxl_512){ - // for (auto c : configs){ + // for (auto c : configs_sdxl_512){ + for (auto c : configs){ test_model model; load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c), std::get<3>(c), std::get<4>(c), std::get<5>(c), true); From a1fb3c150941cd1b5fedea79aa2053c348415aff Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sat, 8 Nov 2025 16:45:59 -0500 Subject: [PATCH 073/122] fixed a bug now split-k can choose a better split factor --- ggml/src/ggml-cuda/conv2d-implicit.cu | 22 +++++++++++++++++----- ggml/src/ggml-cuda/conv2d-implicit.cuh | 4 ++-- tests/test-conv2d.cpp | 8 +++++--- 3 files changed, 24 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 1133626d14..1bf94476ab 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -856,13 +856,10 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half); const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; - // const unsigned int ksplit = 6; // if (BlocksM * BlocksN < nsm && P.c >= 8 * ksplit && (P.c * P.r * P.s) % (8*ksplit) == 0) { - printf("split factor info = %d, %d, %d \n", BlocksM, BlocksN, nsm / (BlocksM * BlocksN)); - if (BlocksM * BlocksN < nsm && nsm / (BlocksM * BlocksN) <= 8 ){ + if (BlocksM * BlocksN < nsm){ int ks = nsm / (BlocksM * BlocksN); - printf("split factor init = %d \n", ks); int j; bool can_split = false; for (j = ks; j >= 2; j--){ @@ -872,7 +869,6 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa } } if(can_split){ - printf("split factor = %d \n", j); if (j == 2) { const unsigned int ksplit = 2; launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else if (j == 9) { + const unsigned int ksplit = 9; + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else if (j == 10) { + const unsigned int ksplit = 10; + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else if (j == 11) { + const unsigned int ksplit = 11; + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else { + const unsigned int ksplit = 12; + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); } return; } diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index 85936e42c6..35764c5b63 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ 
b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -72,7 +72,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB( unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); - if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c && ki < end_k){ + if (thread_row + blockIdx.x * TILE_ROWS < param.k && curR < param.r && curS < param.s && curC < param.c && ki < end_k){ dst_float4[dst_index] = reinterpret_cast(&src[src_index])[0]; }else{ // read 4 halves dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f); @@ -273,7 +273,7 @@ __device__ __forceinline__ void tileMemcpyLoadB( #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++){ const unsigned int src_index = thread_row * src_stride + ki; - if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c && ki < end_k){ + if (thread_row + blockIdx.x * TILE_ROWS < param.k && curR < param.r && curS < param.s && curC < param.c && ki < end_k){ dst_reg[i] = reinterpret_cast(&src[src_index])[0]; }else{ // read 4 halves dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f); diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index 5af7da0a91..c460ca7d87 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -300,7 +300,9 @@ static std::vector> configs = { // std::make_tuple(640,640,104,152,3,3), // std::make_tuple(960,320,104,152,3,3), // std::make_tuple(1280,1280,26,38,3,3), - std::make_tuple(1920,640,32,32,3,3) + // std::make_tuple(1920,640,32,32,3,3) + std::make_tuple(1280,1280,16,16,3,3), + // std::make_tuple(320,640,32,32,3,3), // std::make_tuple(4,320,96,128,3,3), // std::make_tuple(320,4,96,128,3,3), // std::make_tuple(4,320,64,96,3,3), @@ -651,8 +653,8 @@ int main(void) int k = 0; - // for (auto c : configs_sdxl_512){ - for (auto c : configs){ + for (auto c : configs_sdxl_768){ + // for (auto c : configs){ test_model model; load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c), std::get<3>(c), std::get<4>(c), std::get<5>(c), true); From a3fb36fb71843e8f4be4a3f03d22dba05c9cc389 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sat, 8 Nov 2025 18:47:12 -0500 Subject: [PATCH 074/122] make split-k condition check more robust --- ggml/src/ggml-cuda/conv2d-implicit.cu | 4 ++-- tests/test-conv2d.cpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 1bf94476ab..3d086343d7 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -859,7 +859,7 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa // if (BlocksM * BlocksN < nsm && P.c >= 8 * ksplit && (P.c * P.r * P.s) % (8*ksplit) == 0) { if (BlocksM * BlocksN < nsm){ - int ks = nsm / (BlocksM * BlocksN); + int ks = min(12, nsm / (BlocksM * BlocksN)); int j; bool can_split = false; for (j = ks; j >= 2; j--){ @@ -909,7 +909,7 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa const unsigned int ksplit = 11; launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); - } else { + } else if(j == 12) { const unsigned int ksplit = 12; launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index c460ca7d87..c6bdad23eb 100644 --- a/tests/test-conv2d.cpp +++ 
b/tests/test-conv2d.cpp @@ -653,7 +653,7 @@ int main(void) int k = 0; - for (auto c : configs_sdxl_768){ + for (auto c : configs_sdxl_1024){ // for (auto c : configs){ test_model model; load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c), From 6106e9068bf51ae19e1e0cdfc00bc65159d7ad07 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sat, 8 Nov 2025 19:35:29 -0500 Subject: [PATCH 075/122] make CI happy --- ggml/src/ggml-cuda/conv2d-implicit.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 3d086343d7..da2a6868a9 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -857,7 +857,7 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; // if (BlocksM * BlocksN < nsm && P.c >= 8 * ksplit && (P.c * P.r * P.s) % (8*ksplit) == 0) { - if (BlocksM * BlocksN < nsm){ + if (BlocksM * BlocksN < (unsigned int)nsm){ int ks = min(12, nsm / (BlocksM * BlocksN)); int j; From a2db92f41c0708e9c39b2039c019719aa22de72f Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sat, 8 Nov 2025 20:33:05 -0500 Subject: [PATCH 076/122] make CI happy --- tests/test-conv2d.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index c6bdad23eb..db43cf1847 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -653,8 +653,8 @@ int main(void) int k = 0; - for (auto c : configs_sdxl_1024){ - // for (auto c : configs){ + // for (auto c : configs_sdxl_1024){ + for (auto c : configs){ test_model model; load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c), std::get<3>(c), std::get<4>(c), std::get<5>(c), true); From 8e0e944b70c71d10826b56e59219ea416d59e75b Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sun, 9 Nov 2025 00:14:56 -0500 Subject: [PATCH 077/122] reduced uncoalesced global access in filter transpose --- ggml/src/ggml-cuda/conv2d-implicit.cu | 78 +++++++++++++++++++++++++-- tests/test-conv2d.cpp | 8 +-- 2 files changed, 77 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index da2a6868a9..ad3f50c85e 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -10,6 +10,7 @@ constexpr uint WARPSIZE = 32; #define CUDA_NCHW_2_NHWC_TILE_DIM 32 #define CUDA_NCHW_2_NHWC_BLOCK_NM 8 #define CUDA_NCHW_2_NHWC_BLOCK_ROWS 8 +#define CUDA_NCHW_2_NHWC_BLOCK_C 64 //currently not use; in future for split-k kernels @@ -64,6 +65,45 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co } } +template +static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, const int ne00, const int ne01){ + + const int64_t nmat = ne / (ne00 * ne01); + const int64_t n = ne00 * ne01; + + const unsigned int tx = threadIdx.x; + const unsigned int bx = blockIdx.x; + const unsigned int by = blockIdx.y; + // int y = blockIdx.y * CUDA_NCHW_2_NHWC_TILE_DIM + threadIdx.y; + // int tx = blockIdx.y * CUDA_NCHW_2_NHWC_TILE_DIM + threadIdx.x; // transpose block offset + // int ty = blockIdx.x * CUDA_NCHW_2_NHWC_TILE_DIM + threadIdx.y; + + __shared__ src_T tile[rs*blk_c]; + + for(int i = 0; i < CUDA_NCHW_2_NHWC_BLOCK_NM; ++i){ + + const unsigned int imat = by * CUDA_NCHW_2_NHWC_BLOCK_NM + i; + if(imat >= nmat) + break; + for (unsigned int j = 0; j < rs; j++){ + const unsigned int row = (j * 
blk_c + tx) % rs; + const unsigned int col = (j * blk_c + tx) / rs; + const unsigned int src_index = imat*n + bx * blk_c * rs + j * blk_c + tx; + if (src_index < ne) { + tile[row * blk_c + col] = src[src_index]; + } + } + __syncthreads(); + + for (unsigned int j = 0; j < rs; j++){ + const unsigned int dst_index = imat*n + j*ne00 + bx*blk_c + tx; + if(dst_index < ne){ + dst[dst_index] = ggml_cuda_cast(tile[j*blk_c+tx]); + } + } + } +} + template 1 || P.s > 1)) { if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0 && (P.r > 1 || P.s > 1)) { int id = ggml_cuda_get_device(); @@ -826,13 +867,40 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa ne = P.c * P.r * P.s * P.k; ne01 = P.r * P.s; ggml_cuda_pool_alloc kernel_f16(ctx.pool(id), ne); - dim3 dimGrid1((ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, - (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, - (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ; - NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + if (ne01 > 1){ + dim3 dimGrid1((ne00 + CUDA_NCHW_2_NHWC_BLOCK_C - 1) / CUDA_NCHW_2_NHWC_BLOCK_C, + (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM, + 1) ; + if (ne01 == 25) { + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + }else if (ne01 == 16) { + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + }else if (ne01 == 9) { + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + } else if (ne01 == 8) { + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + } else if (ne01 == 7) { + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + } else if (ne01 == 6) { + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + } else if (ne01 == 5) { + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + } else if (ne01 == 4) { + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + } else if (ne01 == 3) { + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + } else if (ne01 == 2) { + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + } else { + dim3 dimGrid2((ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, + (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, + (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ; + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + } + } const half *X_H = input_f16.get(); - const half *K_H = kernel_f16.get(); + const half *K_H = ne01 == 1 ? 
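+                       // For 1x1 filters (ne01 == r*s == 1) the NCHW and NHWC filter
+                       // layouts coincide, so the transpose above is skipped and the
+                       // original filter pointer is used directly.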
K_D : kernel_f16.get(); constexpr unsigned int BM_dim = 256; constexpr unsigned int BN_dim = 256; diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index db43cf1847..844cce8923 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -299,9 +299,9 @@ static std::vector> configs = { // std::make_tuple(640,640,52,76,3,3), // std::make_tuple(640,640,104,152,3,3), // std::make_tuple(960,320,104,152,3,3), - // std::make_tuple(1280,1280,26,38,3,3), + std::make_tuple(1280,1280,26,38,3,3), // std::make_tuple(1920,640,32,32,3,3) - std::make_tuple(1280,1280,16,16,3,3), + // std::make_tuple(1280,1280,16,16,3,3), // std::make_tuple(320,640,32,32,3,3), // std::make_tuple(4,320,96,128,3,3), // std::make_tuple(320,4,96,128,3,3), @@ -653,8 +653,8 @@ int main(void) int k = 0; - // for (auto c : configs_sdxl_1024){ - for (auto c : configs){ + for (auto c : configs_sdxl_1024){ + // for (auto c : configs){ test_model model; load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c), std::get<3>(c), std::get<4>(c), std::get<5>(c), true); From 5ed2c1b7875186e0a1bf8a96bc8b4e4dfbc438cb Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sun, 9 Nov 2025 00:51:51 -0500 Subject: [PATCH 078/122] reduce bank conflicts in filter transpose --- ggml/src/ggml-cuda/conv2d-implicit.cu | 56 +++++++++++++++++++-------- tests/test-conv2d.cpp | 6 +-- 2 files changed, 43 insertions(+), 19 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index ad3f50c85e..00d96656ba 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -28,6 +28,19 @@ static __global__ void reduce_f32(const src_T * __restrict__ x, dst_T * __restri } } +constexpr uint32_t filter_swizzle_mask(uint32_t n, uint32_t m) { + if (n <= 1) return 1; + n--; + n |= n >> 1; + n |= n >> 2; + n |= n >> 4; + n |= n >> 8; + n |= n >> 16; + int count = 0; + while ((m >>= 1) != 0) + ++count; + return n << count; +} template static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, const int ne00, const int ne01){ @@ -65,7 +78,7 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co } } -template +template static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, const int ne00, const int ne01){ const int64_t nmat = ne / (ne00 * ne01); @@ -74,9 +87,6 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co const unsigned int tx = threadIdx.x; const unsigned int bx = blockIdx.x; const unsigned int by = blockIdx.y; - // int y = blockIdx.y * CUDA_NCHW_2_NHWC_TILE_DIM + threadIdx.y; - // int tx = blockIdx.y * CUDA_NCHW_2_NHWC_TILE_DIM + threadIdx.x; // transpose block offset - // int ty = blockIdx.x * CUDA_NCHW_2_NHWC_TILE_DIM + threadIdx.y; __shared__ src_T tile[rs*blk_c]; @@ -89,8 +99,10 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co const unsigned int row = (j * blk_c + tx) % rs; const unsigned int col = (j * blk_c + tx) / rs; const unsigned int src_index = imat*n + bx * blk_c * rs + j * blk_c + tx; + unsigned int idx = row * blk_c + col; + idx = idx ^ ((idx & mask) >> 4); if (src_index < ne) { - tile[row * blk_c + col] = src[src_index]; + tile[idx] = src[src_index]; } } __syncthreads(); @@ -98,7 +110,9 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co for (unsigned int j = 0; j < rs; j++){ const unsigned int dst_index = imat*n + j*ne00 + bx*blk_c + tx; if(dst_index < ne){ - dst[dst_index] = 
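+                // Recompute the same XOR swizzle that was used for the store, so each
+                // thread reads back the element staged for tap j and channel tx; the
+                // swizzle keeps the transpose's strided shared-memory accesses spread
+                // across banks.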
ggml_cuda_cast(tile[j*blk_c+tx]); + unsigned int idx = j*blk_c + tx; + idx = idx ^ ((idx & mask) >> 4); + dst[dst_index] = ggml_cuda_cast(tile[idx]); } } } @@ -872,25 +886,35 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM, 1) ; if (ne01 == 25) { - NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + constexpr unsigned int mask = filter_swizzle_mask(25, CUDA_NCHW_2_NHWC_BLOCK_C); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); }else if (ne01 == 16) { - NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + constexpr unsigned int mask = filter_swizzle_mask(16, CUDA_NCHW_2_NHWC_BLOCK_C); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); }else if (ne01 == 9) { - NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + constexpr unsigned int mask = filter_swizzle_mask(9, CUDA_NCHW_2_NHWC_BLOCK_C); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); } else if (ne01 == 8) { - NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + constexpr unsigned int mask = filter_swizzle_mask(8, CUDA_NCHW_2_NHWC_BLOCK_C); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); } else if (ne01 == 7) { - NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + constexpr unsigned int mask = filter_swizzle_mask(7, CUDA_NCHW_2_NHWC_BLOCK_C); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); } else if (ne01 == 6) { - NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + constexpr unsigned int mask = filter_swizzle_mask(6, CUDA_NCHW_2_NHWC_BLOCK_C); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); } else if (ne01 == 5) { - NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + constexpr unsigned int mask = filter_swizzle_mask(5, CUDA_NCHW_2_NHWC_BLOCK_C); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); } else if (ne01 == 4) { - NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + constexpr unsigned int mask = filter_swizzle_mask(4, CUDA_NCHW_2_NHWC_BLOCK_C); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); } else if (ne01 == 3) { - NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + constexpr unsigned int mask = filter_swizzle_mask(3, CUDA_NCHW_2_NHWC_BLOCK_C); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); } else if (ne01 == 2) { - NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + constexpr unsigned int mask = filter_swizzle_mask(2, CUDA_NCHW_2_NHWC_BLOCK_C); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); } else { dim3 dimGrid2((ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index 844cce8923..b3d5e8c724 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -653,8 +653,8 @@ int main(void) int k = 0; - for (auto c : configs_sdxl_1024){ - // for (auto c : configs){ + // for (auto c : configs_sdxl_1024){ + for (auto c : configs){ test_model model; load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c), std::get<3>(c), std::get<4>(c), std::get<5>(c), true); @@ -671,7 +671,7 @@ int main(void) // fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f); - int iterations = 20; + int iterations = 0; double run_time0; std::vector im2col_data = compute_graph(model, allocr, build_graph_0, iterations, &run_time0); From 496c3599c6ee9bbc377d7fb8d25cc8fe9bbdc330 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sun, 9 Nov 2025 09:23:14 -0500 
Subject: [PATCH 079/122] add loop unrolling --- ggml/src/ggml-cuda/conv2d-implicit.cu | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 00d96656ba..99fa1925d5 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -90,11 +90,13 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co __shared__ src_T tile[rs*blk_c]; +#pragma unroll for(int i = 0; i < CUDA_NCHW_2_NHWC_BLOCK_NM; ++i){ const unsigned int imat = by * CUDA_NCHW_2_NHWC_BLOCK_NM + i; if(imat >= nmat) break; +#pragma unroll for (unsigned int j = 0; j < rs; j++){ const unsigned int row = (j * blk_c + tx) % rs; const unsigned int col = (j * blk_c + tx) / rs; @@ -106,7 +108,7 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co } } __syncthreads(); - +#pragma unroll for (unsigned int j = 0; j < rs; j++){ const unsigned int dst_index = imat*n + j*ne00 + bx*blk_c + tx; if(dst_index < ne){ From 1fdcb05dc8740864e47468fe4946d8bfea823b06 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Mon, 10 Nov 2025 11:47:56 -0500 Subject: [PATCH 080/122] increase maximum split factor to 16; use better heuristics to choose split-K factor, reducing tail effect --- ggml/src/ggml-cuda/conv2d-implicit.cu | 76 ++++++++++++++++----------- tests/test-conv2d.cpp | 2 +- 2 files changed, 45 insertions(+), 33 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 99fa1925d5..09f317e2d3 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -1,4 +1,5 @@ // #include +#include #include "ggml.h" #include "common.cuh" #include "convert.cuh" @@ -951,61 +952,72 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; // if (BlocksM * BlocksN < nsm && P.c >= 8 * ksplit && (P.c * P.r * P.s) % (8*ksplit) == 0) { - if (BlocksM * BlocksN < (unsigned int)nsm){ - - int ks = min(12, nsm / (BlocksM * BlocksN)); - int j; - bool can_split = false; - for (j = ks; j >= 2; j--){ + if (BlocksM * BlocksN < 2*(unsigned int)nsm){ + int j, max_remaining_waves = -1, candidate = -1; + int ks = min(16, nsm / (BlocksM * BlocksN)); + if (ks < 2 && (BlocksM * BlocksN) % nsm < nsm*4/5) + ks = 16; + for (j = 2; j <= ks; j++){ + const int remainder = (BlocksM * BlocksN * j) % nsm; if ((P.c * P.r * P.s) % (8*j) == 0){ - can_split = true; - break; + if (remainder == 0) { + candidate = j; + max_remaining_waves = 0; + break; + } else if (remainder > max_remaining_waves) { + max_remaining_waves = remainder; + candidate = j; + } } } - if(can_split){ + + if(candidate != -1){ + j = candidate; + // printf(" choosing %d, %d \n", j, max_remaining_waves); if (j == 2) { - const unsigned int ksplit = 2; - launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); } else if (j == 3) { - const unsigned int ksplit = 3; - launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); } else if (j == 4) { - const unsigned int ksplit = 4; - launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); } else if (j == 5) { - const unsigned int ksplit = 5; - launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); } else if (j == 6) { - const unsigned int ksplit = 6; - 
launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); } else if (j == 7) { - const unsigned int ksplit = 7; - launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); } else if (j == 8) { - const unsigned int ksplit = 8; - launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); } else if (j == 9) { - const unsigned int ksplit = 9; - launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); } else if (j == 10) { - const unsigned int ksplit = 10; - launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); } else if (j == 11) { - const unsigned int ksplit = 11; - launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); - } else if(j == 12) { - const unsigned int ksplit = 12; - launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else if (j == 13) { + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else if (j == 14) { + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else if (j == 15) { + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else if (j == 16) { + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); } return; diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index b3d5e8c724..a5f2a6aef6 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -653,7 +653,7 @@ int main(void) int k = 0; - // for (auto c : configs_sdxl_1024){ + // for (auto c : configs_sdxl_768){ for (auto c : configs){ test_model model; load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c), From a660d4d45d5184b2910d9b7b9f59e01bd36b69a8 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Mon, 10 Nov 2025 12:39:50 -0500 Subject: [PATCH 081/122] get rid of a convert unary kernel call and fuse the type cast into conv epilogue --- ggml/src/ggml-cuda/conv2d-implicit.cu | 32 +++++++++++++-------------- tests/test-conv2d.cpp | 6 ++--- 2 files changed, 18 insertions(+), 20 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 09f317e2d3..874d40e80b 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -582,11 +582,11 @@ __device__ __forceinline__ void ldmatrix_b( #endif } -template static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, const half * __restrict__ kernel, - half * __restrict__ output, + T * __restrict__ output, const param_t param) { #if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING @@ -763,10 +763,10 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, const uint outOffset = z * NKPQ + n * KPQ + row * PQ + col; - output[outOffset] = res_[0]; + output[outOffset] = ggml_cuda_cast(res_[0]); } else { const uint outOffset = n * KPQ + row * PQ + col; - output[outOffset] = res_[0]; + output[outOffset] = ggml_cuda_cast(res_[0]); } } if (n < param.n && row+1 < param.k && col < PQ) { @@ -774,10 +774,10 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, const uint outOffset = z * NKPQ + n * KPQ + (row+1) * PQ + col; - output[outOffset] = res_[1]; + output[outOffset] = ggml_cuda_cast(res_[1]); } else { const uint outOffset = n * KPQ + (row+1) * PQ + col; - 
output[outOffset] = res_[1]; + output[outOffset] = ggml_cuda_cast(res_[1]); } } } @@ -848,12 +848,12 @@ static void launch_conv2d_implicit_split_kernel(ggml_backend_cuda_context & ctx, int id = ggml_cuda_get_device(); ggml_cuda_pool_alloc Y_H(ctx.pool(id), ksplit * P.k * P.Oh * P.Ow * P.n); - cudaFuncSetAttribute(conv2d_implicit_kernel, + cudaFuncSetAttribute(conv2d_implicit_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75 dim3 gridDim(BlocksN, BlocksM, ksplit); dim3 blockDim(ThreadsN, ThreadsM); - conv2d_implicit_kernel<<>>(X_H, K_H, Y_H.get(), P); const unsigned int nrows = P.n * P.k * P.Oh * P.Ow; @@ -866,7 +866,7 @@ static void launch_conv2d_implicit_split_kernel(ggml_backend_cuda_context & ctx, static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const float * X_D, const half * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) { // if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0 && (P.r > 1 || P.s > 1)) { - if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0 && (P.r > 1 || P.s > 1)) { + if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0) { int id = ggml_cuda_get_device(); @@ -883,8 +883,10 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa ne = P.c * P.r * P.s * P.k; ne01 = P.r * P.s; - ggml_cuda_pool_alloc kernel_f16(ctx.pool(id), ne); + // ggml_cuda_pool_alloc kernel_f16(ctx.pool(id), ne); + ggml_cuda_pool_alloc kernel_f16(ctx.pool(id)); if (ne01 > 1){ + kernel_f16.alloc(ne); dim3 dimGrid1((ne00 + CUDA_NCHW_2_NHWC_BLOCK_C - 1) / CUDA_NCHW_2_NHWC_BLOCK_C, (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM, 1) ; @@ -973,7 +975,6 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa if(candidate != -1){ j = candidate; - // printf(" choosing %d, %d \n", j, max_remaining_waves); if (j == 2) { launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); @@ -1023,18 +1024,15 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa return; } } - ggml_cuda_pool_alloc Y_H(ctx.pool(id), P.k * P.Oh * P.Ow * P.n); - cudaFuncSetAttribute(conv2d_implicit_kernel, + cudaFuncSetAttribute(conv2d_implicit_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75 dim3 gridDim(BlocksN, BlocksM); dim3 blockDim(ThreadsN, ThreadsM); - conv2d_implicit_kernel - <<>>(X_H, K_H, Y_H.get(), P); - const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); - to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow * P.n, st); + <<>>(X_H, K_H, Y_D, P); } else{ conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); } diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index a5f2a6aef6..23a3aab366 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -299,7 +299,7 @@ static std::vector> configs = { // std::make_tuple(640,640,52,76,3,3), // std::make_tuple(640,640,104,152,3,3), // std::make_tuple(960,320,104,152,3,3), - std::make_tuple(1280,1280,26,38,3,3), + // std::make_tuple(1280,1280,26,38,3,3), // std::make_tuple(1920,640,32,32,3,3) // std::make_tuple(1280,1280,16,16,3,3), // std::make_tuple(320,640,32,32,3,3), @@ -317,7 +317,7 @@ static std::vector> configs = { // std::make_tuple(1920,1280,26,38,3,3), // std::make_tuple(2560,1280,26,38,3,3), // std::make_tuple(320,1280,26,38,3,3), - // 
std::make_tuple(512,512,104,152,3,3), + std::make_tuple(512,512,104,152,3,3), // std::make_tuple(512,512,208,304,3,3), // std::make_tuple(512,256,416,608,3,3), // std::make_tuple(256,128,832,1216,3,3), @@ -714,7 +714,7 @@ int main(void) // for(int i = 0; i < ggml_nelements(wino_res); i++) { // for(int i = 0; i < 26*38; i++) { - // // for(int i = 0; i < conv2d_data.size(); i++) { + // for(int i = 0; i < conv2d_data.size(); i++) { // float diff = fabs(im2col_data[i] - conv2d_data[i]); // // if(diff > 0.5) { // printf("(%7.3f, %7.3f, %.2f, %d) \n", From fac6f0adc3daed7fa7620f00e2fd9fe6350e20eb Mon Sep 17 00:00:00 2001 From: bssrdf Date: Mon, 10 Nov 2025 20:05:39 -0500 Subject: [PATCH 082/122] add missing batch index bounds check --- ggml/src/ggml-cuda/conv2d-implicit.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index 35764c5b63..981a183fd9 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -146,7 +146,7 @@ __device__ __forceinline__ void tileMemcpySwizzleA( dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && - curR < param.r && curS < param.s && curC < param.c && ki < end_k){ + curR < param.r && curS < param.s && curC < param.c && n < param.n && ki < end_k){ const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; dst_float4[dst_index] = reinterpret_cast(&src[n * chw + inOffsetTmp])[0]; } else{ @@ -214,7 +214,7 @@ __device__ __forceinline__ void tileMemcpyLoadA( int curH = posh_ori + curR * param.d_h; // input h int curW = posw_ori + curS * param.d_w; // input w if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && - curR < param.r && curS < param.s && curC < param.c && ki < end_k){ + curR < param.r && curS < param.s && curC < param.c && n < param.n && ki < end_k){ const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; dst_reg[i] = reinterpret_cast(&src[n * chw + inOffsetTmp])[0]; } else{ From c33e4301dcdda9006804bd12a073f55d15af57b6 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 12 Nov 2025 10:26:01 -0500 Subject: [PATCH 083/122] m16n8k16 mma works; to be cleaned up --- ggml/src/ggml-cuda/conv2d-implicit.cu | 325 +++++++++++++++++++------- tests/test-conv2d.cpp | 6 +- 2 files changed, 250 insertions(+), 81 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 874d40e80b..1fceeb9a6e 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -343,13 +343,15 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, template __device__ __forceinline__ void ldmatrix_a( const half* src, - half (®)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] + half (®)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][8] ){ #if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING - static_assert(mma_tiles_per_warp_m == 8, "mma_tiles_per_warp_m must be 4"); - static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4"); + static_assert(mma_tiles_per_warp_m == 8, "mma_tiles_per_warp_m must be 8"); + // static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4"); + static_assert(mma_tiles_per_warp_k == 2, "mma_tiles_per_warp_k must be 2"); - uint32_t (®_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][2] = 
reinterpret_cast(reg); + // uint32_t (®_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][2] = reinterpret_cast(reg); + uint32_t (®_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] = reinterpret_cast(reg); unsigned int logical_offset = (threadIdx.x % 32) * smem_stride; unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b10000000) >> 4); swizzled_offset = swizzled_offset ^ ((swizzled_offset & 0b1100000) >> 2); @@ -390,7 +392,104 @@ __device__ __forceinline__ void ldmatrix_a( src_addr ^= 0b10000; + // // 1 + // asm volatile ( + // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + // "{%0, %1, %2, %3}, [%4];" + // : "=r"(reg_[0][1][0]), "=r"(reg_[0][1][1]), "=r"(reg_[1][1][0]), "=r"(reg_[1][1][1]) + // : "r"(src_addr) + // ); + + // // 1 + // asm volatile ( + // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + // "{%0, %1, %2, %3}, [%4];" + // : "=r"(reg_[2][1][0]), "=r"(reg_[2][1][1]), "=r"(reg_[3][1][0]), "=r"(reg_[3][1][1]) + // : "r"(src_addr + 32 * smem_stride_) + // ); + + // // 1 + // asm volatile ( + // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + // "{%0, %1, %2, %3}, [%4];" + // : "=r"(reg_[4][1][0]), "=r"(reg_[4][1][1]), "=r"(reg_[5][1][0]), "=r"(reg_[5][1][1]) + // : "r"(src_addr + 64 * smem_stride_) + // ); + + // // 1 + // asm volatile ( + // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + // "{%0, %1, %2, %3}, [%4];" + // : "=r"(reg_[6][1][0]), "=r"(reg_[6][1][1]), "=r"(reg_[7][1][0]), "=r"(reg_[7][1][1]) + // : "r"(src_addr + 96 * smem_stride_) + // ); + + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[0][0][2]), "=r"(reg_[0][0][3]), "=r"(reg_[1][0][2]), "=r"(reg_[1][0][3]) + : "r"(src_addr) + ); + // 1 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[2][0][2]), "=r"(reg_[2][0][3]), "=r"(reg_[3][0][2]), "=r"(reg_[3][0][3]) + : "r"(src_addr + 32 * smem_stride_) + ); + + // 1 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[4][0][2]), "=r"(reg_[4][0][3]), "=r"(reg_[5][0][2]), "=r"(reg_[5][0][3]) + : "r"(src_addr + 64 * smem_stride_) + ); + + // 1 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[6][0][2]), "=r"(reg_[6][0][3]), "=r"(reg_[7][0][2]), "=r"(reg_[7][0][3]) + : "r"(src_addr + 96 * smem_stride_) + ); + + src_addr ^= 0b110000; + + // // 2 + // asm volatile ( + // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + // "{%0, %1, %2, %3}, [%4];" + // : "=r"(reg_[0][2][0]), "=r"(reg_[0][2][1]), "=r"(reg_[1][2][0]), "=r"(reg_[1][2][1]) + // : "r"(src_addr) + // ); + + // // 2 + // asm volatile ( + // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + // "{%0, %1, %2, %3}, [%4];" + // : "=r"(reg_[2][2][0]), "=r"(reg_[2][2][1]), "=r"(reg_[3][2][0]), "=r"(reg_[3][2][1]) + // : "r"(src_addr + 32 * smem_stride_) + // ); + + // // 2 + // asm volatile ( + // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + // "{%0, %1, %2, %3}, [%4];" + // : "=r"(reg_[4][2][0]), "=r"(reg_[4][2][1]), "=r"(reg_[5][2][0]), "=r"(reg_[5][2][1]) + // : "r"(src_addr + 64 * smem_stride_) + // ); + + // // 2 + // asm volatile ( + // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + // "{%0, %1, %2, %3}, [%4];" + // : "=r"(reg_[6][2][0]), "=r"(reg_[6][2][1]), "=r"(reg_[7][2][0]), "=r"(reg_[7][2][1]) + // : "r"(src_addr + 96 * smem_stride_) + // ); + + // 2 asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" @@ -398,7 +497,7 @@ __device__ __forceinline__ void 
ldmatrix_a( : "r"(src_addr) ); - // 1 + // 2 asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" @@ -406,7 +505,7 @@ __device__ __forceinline__ void ldmatrix_a( : "r"(src_addr + 32 * smem_stride_) ); - // 1 + // 2 asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" @@ -414,7 +513,7 @@ __device__ __forceinline__ void ldmatrix_a( : "r"(src_addr + 64 * smem_stride_) ); - // 1 + // 2 asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" @@ -422,46 +521,45 @@ __device__ __forceinline__ void ldmatrix_a( : "r"(src_addr + 96 * smem_stride_) ); - src_addr ^= 0b110000; - - // 2 - asm volatile ( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - "{%0, %1, %2, %3}, [%4];" - : "=r"(reg_[0][2][0]), "=r"(reg_[0][2][1]), "=r"(reg_[1][2][0]), "=r"(reg_[1][2][1]) - : "r"(src_addr) - ); - - // 2 - asm volatile ( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - "{%0, %1, %2, %3}, [%4];" - : "=r"(reg_[2][2][0]), "=r"(reg_[2][2][1]), "=r"(reg_[3][2][0]), "=r"(reg_[3][2][1]) - : "r"(src_addr + 32 * smem_stride_) - ); - - // 2 - asm volatile ( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - "{%0, %1, %2, %3}, [%4];" - : "=r"(reg_[4][2][0]), "=r"(reg_[4][2][1]), "=r"(reg_[5][2][0]), "=r"(reg_[5][2][1]) - : "r"(src_addr + 64 * smem_stride_) - ); - - // 2 - asm volatile ( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - "{%0, %1, %2, %3}, [%4];" - : "=r"(reg_[6][2][0]), "=r"(reg_[6][2][1]), "=r"(reg_[7][2][0]), "=r"(reg_[7][2][1]) - : "r"(src_addr + 96 * smem_stride_) - ); src_addr ^= 0b10000; + // // 3 + // asm volatile ( + // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + // "{%0, %1, %2, %3}, [%4];" + // : "=r"(reg_[0][3][0]), "=r"(reg_[0][3][1]), "=r"(reg_[1][3][0]), "=r"(reg_[1][3][1]) + // : "r"(src_addr) + // ); + + // // 3 + // asm volatile ( + // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + // "{%0, %1, %2, %3}, [%4];" + // : "=r"(reg_[2][3][0]), "=r"(reg_[2][3][1]), "=r"(reg_[3][3][0]), "=r"(reg_[3][3][1]) + // : "r"(src_addr + 32 * smem_stride_) + // ); + + // // 3 + // asm volatile ( + // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + // "{%0, %1, %2, %3}, [%4];" + // : "=r"(reg_[4][3][0]), "=r"(reg_[4][3][1]), "=r"(reg_[5][3][0]), "=r"(reg_[5][3][1]) + // : "r"(src_addr + 64 * smem_stride_) + // ); + + // // 3 + // asm volatile ( + // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + // "{%0, %1, %2, %3}, [%4];" + // : "=r"(reg_[6][3][0]), "=r"(reg_[6][3][1]), "=r"(reg_[7][3][0]), "=r"(reg_[7][3][1]) + // : "r"(src_addr + 96 * smem_stride_) + // ); + // 3 asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" - : "=r"(reg_[0][3][0]), "=r"(reg_[0][3][1]), "=r"(reg_[1][3][0]), "=r"(reg_[1][3][1]) + : "=r"(reg_[0][1][2]), "=r"(reg_[0][1][3]), "=r"(reg_[1][1][2]), "=r"(reg_[1][1][3]) : "r"(src_addr) ); @@ -469,7 +567,7 @@ __device__ __forceinline__ void ldmatrix_a( asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" - : "=r"(reg_[2][3][0]), "=r"(reg_[2][3][1]), "=r"(reg_[3][3][0]), "=r"(reg_[3][3][1]) + : "=r"(reg_[2][1][2]), "=r"(reg_[2][1][3]), "=r"(reg_[3][1][2]), "=r"(reg_[3][1][3]) : "r"(src_addr + 32 * smem_stride_) ); @@ -477,7 +575,7 @@ __device__ __forceinline__ void ldmatrix_a( asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" - : "=r"(reg_[4][3][0]), "=r"(reg_[4][3][1]), "=r"(reg_[5][3][0]), "=r"(reg_[5][3][1]) + : "=r"(reg_[4][1][2]), "=r"(reg_[4][1][3]), "=r"(reg_[5][1][2]), "=r"(reg_[5][1][3]) : "r"(src_addr + 
64 * smem_stride_) ); @@ -485,7 +583,7 @@ __device__ __forceinline__ void ldmatrix_a( asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" - : "=r"(reg_[6][3][0]), "=r"(reg_[6][3][1]), "=r"(reg_[7][3][0]), "=r"(reg_[7][3][1]) + : "=r"(reg_[6][1][2]), "=r"(reg_[6][1][3]), "=r"(reg_[7][1][2]), "=r"(reg_[7][1][3]) : "r"(src_addr + 96 * smem_stride_) ); #else @@ -498,14 +596,17 @@ __device__ __forceinline__ void ldmatrix_a( template __device__ __forceinline__ void ldmatrix_b( const half* src, - half (®)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2] + // half (®)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2] + half (®)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][4] ){ #if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING - static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4"); + // static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4"); + static_assert(mma_tiles_per_warp_k == 2, "mma_tiles_per_warp_k must be 2"); static_assert(mma_tiles_per_warp_n == 8, "mma_tiles_per_warp_n must be 8"); - uint32_t (®_) [4][8] = reinterpret_cast(reg); + // uint32_t (®_) [4][8] = reinterpret_cast(reg); + uint32_t (®_) [2][8][2] = reinterpret_cast(reg); unsigned int logical_offset = (threadIdx.x % 32) * smem_stride; unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b10000000) >> 4); swizzled_offset = swizzled_offset ^ ((swizzled_offset & 0b1100000) >> 2); @@ -513,10 +614,25 @@ __device__ __forceinline__ void ldmatrix_b( constexpr unsigned int smem_stride_ = smem_stride * sizeof(half); // convert stride to bytes // 0 + // asm volatile ( + // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + // "{%0, %1, %2, %3}, [%4];" + // : "=r"(reg_[0][0]), "=r"(reg_[0][1]), "=r"(reg_[0][2]), "=r"(reg_[0][3]) + // : "r"(src_addr) + // ); + + + // asm volatile ( + // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + // "{%0, %1, %2, %3}, [%4];" + // : "=r"(reg_[0][4]), "=r"(reg_[0][5]), "=r"(reg_[0][6]), "=r"(reg_[0][7]) + // : "r"(src_addr + 32 * smem_stride_) + // ); + asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" - : "=r"(reg_[0][0]), "=r"(reg_[0][1]), "=r"(reg_[0][2]), "=r"(reg_[0][3]) + : "=r"(reg_[0][0][0]), "=r"(reg_[0][1][0]), "=r"(reg_[0][2][0]), "=r"(reg_[0][3][0]) : "r"(src_addr) ); @@ -524,55 +640,97 @@ __device__ __forceinline__ void ldmatrix_b( asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" - : "=r"(reg_[0][4]), "=r"(reg_[0][5]), "=r"(reg_[0][6]), "=r"(reg_[0][7]) + : "=r"(reg_[0][4][0]), "=r"(reg_[0][5][0]), "=r"(reg_[0][6][0]), "=r"(reg_[0][7][0]) : "r"(src_addr + 32 * smem_stride_) ); src_addr ^= 0b10000; + // asm volatile ( + // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + // "{%0, %1, %2, %3}, [%4];" + // : "=r"(reg_[1][0]), "=r"(reg_[1][1]), "=r"(reg_[1][2]), "=r"(reg_[1][3]) + // : "r"(src_addr) + // ); + + // asm volatile ( + // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + // "{%0, %1, %2, %3}, [%4];" + // : "=r"(reg_[1][4]), "=r"(reg_[1][5]), "=r"(reg_[1][6]), "=r"(reg_[1][7]) + // : "r"(src_addr + 32 * smem_stride_) + // ); + asm volatile ( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - "{%0, %1, %2, %3}, [%4];" - : "=r"(reg_[1][0]), "=r"(reg_[1][1]), "=r"(reg_[1][2]), "=r"(reg_[1][3]) - : "r"(src_addr) - ); + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[0][0][1]), "=r"(reg_[0][1][1]), "=r"(reg_[0][2][1]), "=r"(reg_[0][3][1]) + : "r"(src_addr) + ); asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, 
%2, %3}, [%4];" - : "=r"(reg_[1][4]), "=r"(reg_[1][5]), "=r"(reg_[1][6]), "=r"(reg_[1][7]) + : "=r"(reg_[0][4][1]), "=r"(reg_[0][5][1]), "=r"(reg_[0][6][1]), "=r"(reg_[0][7][1]) : "r"(src_addr + 32 * smem_stride_) ); src_addr ^= 0b110000; + // asm volatile ( + // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + // "{%0, %1, %2, %3}, [%4];" + // : "=r"(reg_[2][0]), "=r"(reg_[2][1]), "=r"(reg_[2][2]), "=r"(reg_[2][3]) + // : "r"(src_addr) + // ); + + // asm volatile ( + // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + // "{%0, %1, %2, %3}, [%4];" + // : "=r"(reg_[2][4]), "=r"(reg_[2][5]), "=r"(reg_[2][6]), "=r"(reg_[2][7]) + // : "r"(src_addr + 32 * smem_stride_) + // ); + asm volatile ( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - "{%0, %1, %2, %3}, [%4];" - : "=r"(reg_[2][0]), "=r"(reg_[2][1]), "=r"(reg_[2][2]), "=r"(reg_[2][3]) - : "r"(src_addr) - ); + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[1][0][0]), "=r"(reg_[1][1][0]), "=r"(reg_[1][2][0]), "=r"(reg_[1][3][0]) + : "r"(src_addr) + ); asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" - : "=r"(reg_[2][4]), "=r"(reg_[2][5]), "=r"(reg_[2][6]), "=r"(reg_[2][7]) + : "=r"(reg_[1][4][0]), "=r"(reg_[1][5][0]), "=r"(reg_[1][6][0]), "=r"(reg_[1][7][0]) : "r"(src_addr + 32 * smem_stride_) ); src_addr ^= 0b10000; + // asm volatile ( + // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + // "{%0, %1, %2, %3}, [%4];" + // : "=r"(reg_[3][0]), "=r"(reg_[3][1]), "=r"(reg_[3][2]), "=r"(reg_[3][3]) + // : "r"(src_addr) + // ); + + // asm volatile ( + // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + // "{%0, %1, %2, %3}, [%4];" + // : "=r"(reg_[3][4]), "=r"(reg_[3][5]), "=r"(reg_[3][6]), "=r"(reg_[3][7]) + // : "r"(src_addr + 32 * smem_stride_) + // ); + asm volatile ( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - "{%0, %1, %2, %3}, [%4];" - : "=r"(reg_[3][0]), "=r"(reg_[3][1]), "=r"(reg_[3][2]), "=r"(reg_[3][3]) - : "r"(src_addr) - ); + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[1][0][1]), "=r"(reg_[1][1][1]), "=r"(reg_[1][2][1]), "=r"(reg_[1][3][1]) + : "r"(src_addr) + ); asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" - : "=r"(reg_[3][4]), "=r"(reg_[3][5]), "=r"(reg_[3][6]), "=r"(reg_[3][7]) + : "=r"(reg_[1][4][1]), "=r"(reg_[1][5][1]), "=r"(reg_[1][6][1]), "=r"(reg_[1][7][1]) : "r"(src_addr + 32 * smem_stride_) ); #else @@ -602,7 +760,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, const unsigned int NKPQ = param.n * KPQ; // loop bounds, constexpr where possible allows for loop unrolling - constexpr unsigned int mma_tiles_per_warp_k = 4; + constexpr unsigned int mma_tiles_per_warp_k = 2; constexpr unsigned int mma_tiles_per_warp_m = WM / MMA_M; constexpr unsigned int mma_tiles_per_warp_n = WN / MMA_N; const unsigned int z = blockIdx.z; @@ -629,13 +787,13 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, // declare register storage // ptx instructions expect uint32_t registers, where each uint32_t is 2 halfs packed together uint32_t acc_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][2]; - uint32_t A_register[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2]; - uint32_t B_register[mma_tiles_per_warp_k][mma_tiles_per_warp_n]; + uint32_t A_register[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4]; + uint32_t B_register[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2]; // convenience cast to half for register storage half 
(&acc_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_n][4] = reinterpret_cast(acc_register); - half (&A_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] = reinterpret_cast(A_register); - half (&B_register_) [mma_tiles_per_warp_k][mma_tiles_per_warp_n][2] = reinterpret_cast(B_register); + half (&A_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][8] = reinterpret_cast(A_register); + half (&B_register_) [mma_tiles_per_warp_k][mma_tiles_per_warp_n][4] = reinterpret_cast(B_register); // accumulators start at 0 for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++){ @@ -685,15 +843,26 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++){ #pragma unroll for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++){ + // asm volatile ( + // "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " + // "{%0, %1}, " + // "{%2, %3}, " + // "{%4}, " + // "{%5, %6};" + // : "=r"(acc_register[mma_m][mma_n][0]), "=r"(acc_register[mma_m][mma_n][1]) + // : "r"(A_register[mma_m][mma_k][0]), "r"(A_register[mma_m][mma_k][1]), + // "r"(B_register[mma_k][mma_n]) + // "r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1]) + // ); asm volatile ( - "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " "{%0, %1}, " - "{%2, %3}, " - "{%4}, " - "{%5, %6};" + "{%2, %3, %4, %5}, " + "{%6, %7}, " + "{%8, %9};" : "=r"(acc_register[mma_m][mma_n][0]), "=r"(acc_register[mma_m][mma_n][1]) - : "r"(A_register[mma_m][mma_k][0]), "r"(A_register[mma_m][mma_k][1]), - "r"(B_register[mma_k][mma_n]) + : "r"(A_register[mma_m][mma_k][0]), "r"(A_register[mma_m][mma_k][1]),"r"(A_register[mma_m][mma_k][2]), "r"(A_register[mma_m][mma_k][3]), + "r"(B_register[mma_k][mma_n][0]), "r"(B_register[mma_k][mma_n][1]) "r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1]) ); } diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index 23a3aab366..8ee0747989 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -301,7 +301,7 @@ static std::vector> configs = { // std::make_tuple(960,320,104,152,3,3), // std::make_tuple(1280,1280,26,38,3,3), // std::make_tuple(1920,640,32,32,3,3) - // std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), // std::make_tuple(320,640,32,32,3,3), // std::make_tuple(4,320,96,128,3,3), // std::make_tuple(320,4,96,128,3,3), @@ -317,7 +317,7 @@ static std::vector> configs = { // std::make_tuple(1920,1280,26,38,3,3), // std::make_tuple(2560,1280,26,38,3,3), // std::make_tuple(320,1280,26,38,3,3), - std::make_tuple(512,512,104,152,3,3), + // std::make_tuple(512,512,104,152,3,3), // std::make_tuple(512,512,208,304,3,3), // std::make_tuple(512,256,416,608,3,3), // std::make_tuple(256,128,832,1216,3,3), @@ -653,7 +653,7 @@ int main(void) int k = 0; - // for (auto c : configs_sdxl_768){ + // for (auto c : configs_sdxl_1024){ for (auto c : configs){ test_model model; load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c), From ea438d8b0e53d72f25b4f45c290afb57688da29e Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 12 Nov 2025 11:32:27 -0500 Subject: [PATCH 084/122] trying to reduce integer ops; simply code --- ggml/src/ggml-cuda/conv2d-implicit.cu | 46 +++++++++++---------------- tests/test-conv2d.cpp | 18 +++++------ 2 files changed, 27 insertions(+), 37 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 
1fceeb9a6e..654d2dffe4 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -805,6 +805,9 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, } } + const unsigned int A_warp_tile_offset = warp_m * WM * BK; + const unsigned int B_warp_tile_offset = warp_n * WN * BK; + static_assert(BM == 256); static_assert(BN == 256); static_assert(BK == 32); @@ -825,13 +828,11 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, __syncthreads(); if (block_k != num_block_tiles_k){ - const half* A_block_gmem = input; - const half* B_block_gmem = kernel + (block_n * BN * weightKOffset); tileMemcpyLoadA(A_block_gmem, A_gmem_cache_reg, block_k * BK, start_k, end_k, inChannelOffset, param); tileMemcpyLoadB(B_block_gmem, B_gmem_cache_reg, block_k * BK, start_k, end_k, weightKOffset, param); } - half* A_warp_tile = A_block_smem + (warp_m * WM * BK); - half* B_warp_tile = B_block_smem + (warp_n * WN * BK); + half* A_warp_tile = A_block_smem + A_warp_tile_offset; + half* B_warp_tile = B_block_smem + B_warp_tile_offset; ldmatrix_a(A_warp_tile, A_register_); ldmatrix_b(B_warp_tile, B_register_); @@ -886,23 +887,25 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, const uint lane_id = threadIdx.x % WARPSIZE; const uint mma_row = lane_id / 4; const uint mma_col = lane_id % 4; - const uint output_lds_addr = warp_m * WM * BN/2 + lane_id * BN/2 + warp_n * WN/2; - const uint output_sts_addr = warp_m * WM * BN/2 + mma_row * BN/2 + warp_n * WN/2 + mma_col * 2; + const uint warp_offset = warp_m * WM * BN/2 + warp_n * WN/2; + const uint output_lds_addr = warp_offset + lane_id * BN/2; + const uint output_sts_addr = warp_offset + mma_row * BN/2 + mma_col * 2; const uint m_idx = block_n * BN + warp_n * WN; const uint n_idx = block_m * BM + warp_m * WM + lane_id; #pragma unroll for (int i = 0; i < 2; ++i) { + const unsigned int i_offset = i * mma_tiles_per_warp_n/2; __syncthreads(); #pragma unroll for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++) { - for (unsigned int mma_n = i * mma_tiles_per_warp_n/2; mma_n < (i+1)*mma_tiles_per_warp_n/2; mma_n++) + const unsigned int mma_m_offset = output_sts_addr + mma_m * MMA_M * BN / 2; + for (unsigned int mma_n = i_offset; mma_n < (i+1)*mma_tiles_per_warp_n/2; mma_n++) { uint32_t (®_)[2] = reinterpret_cast(acc_register_[mma_m][mma_n]); - uint idx = output_sts_addr + - mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N; + uint idx = mma_m_offset + (mma_n - i_offset) * MMA_N; idx = idx ^ ((idx & 0b110000000000) >> 9); idx = idx ^ ((idx & 0b1110000000) >> 4); uint32_t* dst_ptr = reinterpret_cast(&smemoutput[idx]); @@ -913,6 +916,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, } } __syncthreads(); + const unsigned int m_i_wn = m_idx + i * WN / 2; #pragma unroll for (int subk = 0; subk < WN / 4; ++subk){ @@ -925,29 +929,15 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, const uint gemm_i = n_idx + j*32; const int n = fastdiv(gemm_i, param.OHOW_fastdiv); const int col = fastmodulo(gemm_i, param.OHOW_fastdiv); - uint32_t dst_ptr = *(reinterpret_cast(&smemoutput[idx+j*32*BN/2])); + uint32_t dst_ptr = *(reinterpret_cast(&smemoutput[idx+j*16*BN])); // 32*BN/2 = 16*BN half (&res_)[2] = reinterpret_cast(dst_ptr); if (n < param.n && row < param.k && col < PQ) { - if constexpr (ksplit > 0) { - const uint outOffset = z * NKPQ + - n * KPQ + - row * PQ + col; - 
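// Minimal sketch of the flattened output indexing used above, assuming an output layout
// of [N, K, Oh, Ow] with PQ = Oh*Ow, KPQ = K*PQ and NKPQ = N*KPQ. When the reduction is
// split over ksplit slices, each slice z writes its partial sums into its own NKPQ-sized
// plane that a later reduction kernel accumulates. conv_out_offset is an illustrative
// helper name; in the kernel ksplit is a compile-time template parameter, shown here as
// a plain argument.
__device__ __forceinline__ unsigned int conv_out_offset(unsigned int z, unsigned int n,
                                                        unsigned int row, unsigned int col,
                                                        unsigned int NKPQ, unsigned int KPQ,
                                                        unsigned int PQ, int ksplit) {
    return ((ksplit > 0) ? z * NKPQ : 0) + n * KPQ + row * PQ + col;
}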
output[outOffset] = ggml_cuda_cast(res_[0]); - } else { - const uint outOffset = n * KPQ + row * PQ + col; - output[outOffset] = ggml_cuda_cast(res_[0]); - } + const uint outOffset = ((ksplit > 0) ? z * NKPQ : 0) + n * KPQ + row * PQ + col; + output[outOffset] = ggml_cuda_cast(res_[0]); } if (n < param.n && row+1 < param.k && col < PQ) { - if constexpr (ksplit > 0) { - const uint outOffset = z * NKPQ + - n * KPQ + - (row+1) * PQ + col; - output[outOffset] = ggml_cuda_cast(res_[1]); - } else { - const uint outOffset = n * KPQ + (row+1) * PQ + col; - output[outOffset] = ggml_cuda_cast(res_[1]); - } + const uint outOffset = ((ksplit > 0) ? z * NKPQ : 0) + n * KPQ + (row+1) * PQ + col; + output[outOffset] = ggml_cuda_cast(res_[1]); } } } diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index 8ee0747989..90ef1e5237 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -714,15 +714,15 @@ int main(void) // for(int i = 0; i < ggml_nelements(wino_res); i++) { // for(int i = 0; i < 26*38; i++) { - // for(int i = 0; i < conv2d_data.size(); i++) { - // float diff = fabs(im2col_data[i] - conv2d_data[i]); - // // if(diff > 0.5) { - // printf("(%7.3f, %7.3f, %.2f, %d) \n", - // im2col_data[i], conv2d_data[i], - // diff, i); - // // break; - // // } - // } + for(int i = 0; i < conv2d_data.size(); i++) { + float diff = fabs(im2col_data[i] - conv2d_data[i]); + // if(diff > 0.5) { + printf("(%7.3f, %7.3f, %.2f, %d) \n", + im2col_data[i], conv2d_data[i], + diff, i); + // break; + // } + } ggml_free(model.ctx); ggml_backend_buffer_free(model.buffer); From 9f498d29f1a3652bdbe426dd4802cc616b653aa2 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 12 Nov 2025 11:55:15 -0500 Subject: [PATCH 085/122] only enable m16n8k16 on ampere or above --- ggml/src/ggml-cuda/conv2d-implicit.cu | 388 ++++++++++++++------------ tests/test-conv2d.cpp | 18 +- 2 files changed, 226 insertions(+), 180 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 654d2dffe4..529a0b50fd 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -343,15 +343,25 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, template __device__ __forceinline__ void ldmatrix_a( const half* src, +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE half (®)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][8] +#else + half (®)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] +#endif ){ #if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING static_assert(mma_tiles_per_warp_m == 8, "mma_tiles_per_warp_m must be 8"); - // static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4"); +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE static_assert(mma_tiles_per_warp_k == 2, "mma_tiles_per_warp_k must be 2"); +#else + static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4"); +#endif - // uint32_t (®_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][2] = reinterpret_cast(reg); +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE uint32_t (®_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] = reinterpret_cast(reg); +#else + uint32_t (®_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][2] = reinterpret_cast(reg); +#endif unsigned int logical_offset = (threadIdx.x % 32) * smem_stride; unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b10000000) >> 4); swizzled_offset = swizzled_offset ^ ((swizzled_offset & 0b1100000) >> 2); @@ -392,38 +402,8 @@ __device__ __forceinline__ void ldmatrix_a( src_addr ^= 0b10000; - // // 1 - // asm 
volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[0][1][0]), "=r"(reg_[0][1][1]), "=r"(reg_[1][1][0]), "=r"(reg_[1][1][1]) - // : "r"(src_addr) - // ); - - // // 1 - // asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[2][1][0]), "=r"(reg_[2][1][1]), "=r"(reg_[3][1][0]), "=r"(reg_[3][1][1]) - // : "r"(src_addr + 32 * smem_stride_) - // ); - - // // 1 - // asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[4][1][0]), "=r"(reg_[4][1][1]), "=r"(reg_[5][1][0]), "=r"(reg_[5][1][1]) - // : "r"(src_addr + 64 * smem_stride_) - // ); - - // // 1 - // asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[6][1][0]), "=r"(reg_[6][1][1]), "=r"(reg_[7][1][0]), "=r"(reg_[7][1][1]) - // : "r"(src_addr + 96 * smem_stride_) - // ); - + // 1 +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" @@ -455,41 +435,43 @@ __device__ __forceinline__ void ldmatrix_a( : "r"(src_addr + 96 * smem_stride_) ); +#else + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[0][1][0]), "=r"(reg_[0][1][1]), "=r"(reg_[1][1][0]), "=r"(reg_[1][1][1]) + : "r"(src_addr) + ); + + // 1 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[2][1][0]), "=r"(reg_[2][1][1]), "=r"(reg_[3][1][0]), "=r"(reg_[3][1][1]) + : "r"(src_addr + 32 * smem_stride_) + ); + + // 1 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[4][1][0]), "=r"(reg_[4][1][1]), "=r"(reg_[5][1][0]), "=r"(reg_[5][1][1]) + : "r"(src_addr + 64 * smem_stride_) + ); + + // 1 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[6][1][0]), "=r"(reg_[6][1][1]), "=r"(reg_[7][1][0]), "=r"(reg_[7][1][1]) + : "r"(src_addr + 96 * smem_stride_) + ); +#endif + src_addr ^= 0b110000; - // // 2 - // asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[0][2][0]), "=r"(reg_[0][2][1]), "=r"(reg_[1][2][0]), "=r"(reg_[1][2][1]) - // : "r"(src_addr) - // ); - - // // 2 - // asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[2][2][0]), "=r"(reg_[2][2][1]), "=r"(reg_[3][2][0]), "=r"(reg_[3][2][1]) - // : "r"(src_addr + 32 * smem_stride_) - // ); - - // // 2 - // asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[4][2][0]), "=r"(reg_[4][2][1]), "=r"(reg_[5][2][0]), "=r"(reg_[5][2][1]) - // : "r"(src_addr + 64 * smem_stride_) - // ); - - // // 2 - // asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[6][2][0]), "=r"(reg_[6][2][1]), "=r"(reg_[7][2][0]), "=r"(reg_[7][2][1]) - // : "r"(src_addr + 96 * smem_stride_) - // ); - // 2 +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" @@ -520,42 +502,42 @@ __device__ __forceinline__ void ldmatrix_a( : "=r"(reg_[6][1][0]), "=r"(reg_[6][1][1]), "=r"(reg_[7][1][0]), "=r"(reg_[7][1][1]) : "r"(src_addr + 96 * smem_stride_) ); +#else + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : 
"=r"(reg_[0][2][0]), "=r"(reg_[0][2][1]), "=r"(reg_[1][2][0]), "=r"(reg_[1][2][1]) + : "r"(src_addr) + ); + // 2 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[2][2][0]), "=r"(reg_[2][2][1]), "=r"(reg_[3][2][0]), "=r"(reg_[3][2][1]) + : "r"(src_addr + 32 * smem_stride_) + ); + + // 2 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[4][2][0]), "=r"(reg_[4][2][1]), "=r"(reg_[5][2][0]), "=r"(reg_[5][2][1]) + : "r"(src_addr + 64 * smem_stride_) + ); + + // 2 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[6][2][0]), "=r"(reg_[6][2][1]), "=r"(reg_[7][2][0]), "=r"(reg_[7][2][1]) + : "r"(src_addr + 96 * smem_stride_) + ); +#endif src_addr ^= 0b10000; - // // 3 - // asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[0][3][0]), "=r"(reg_[0][3][1]), "=r"(reg_[1][3][0]), "=r"(reg_[1][3][1]) - // : "r"(src_addr) - // ); - - // // 3 - // asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[2][3][0]), "=r"(reg_[2][3][1]), "=r"(reg_[3][3][0]), "=r"(reg_[3][3][1]) - // : "r"(src_addr + 32 * smem_stride_) - // ); - - // // 3 - // asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[4][3][0]), "=r"(reg_[4][3][1]), "=r"(reg_[5][3][0]), "=r"(reg_[5][3][1]) - // : "r"(src_addr + 64 * smem_stride_) - // ); - - // // 3 - // asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[6][3][0]), "=r"(reg_[6][3][1]), "=r"(reg_[7][3][0]), "=r"(reg_[7][3][1]) - // : "r"(src_addr + 96 * smem_stride_) - // ); - // 3 +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" @@ -586,6 +568,38 @@ __device__ __forceinline__ void ldmatrix_a( : "=r"(reg_[6][1][2]), "=r"(reg_[6][1][3]), "=r"(reg_[7][1][2]), "=r"(reg_[7][1][3]) : "r"(src_addr + 96 * smem_stride_) ); +#else + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[0][3][0]), "=r"(reg_[0][3][1]), "=r"(reg_[1][3][0]), "=r"(reg_[1][3][1]) + : "r"(src_addr) + ); + + // 3 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[2][3][0]), "=r"(reg_[2][3][1]), "=r"(reg_[3][3][0]), "=r"(reg_[3][3][1]) + : "r"(src_addr + 32 * smem_stride_) + ); + + // 3 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[4][3][0]), "=r"(reg_[4][3][1]), "=r"(reg_[5][3][0]), "=r"(reg_[5][3][1]) + : "r"(src_addr + 64 * smem_stride_) + ); + + // 3 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[6][3][0]), "=r"(reg_[6][3][1]), "=r"(reg_[7][3][0]), "=r"(reg_[7][3][1]) + : "r"(src_addr + 96 * smem_stride_) + ); +#endif #else GGML_UNUSED(src); GGML_UNUSED(reg); @@ -596,17 +610,26 @@ __device__ __forceinline__ void ldmatrix_a( template __device__ __forceinline__ void ldmatrix_b( const half* src, - // half (®)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2] +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE half (®)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][4] +#else + half (®)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2] +#endif ){ #if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING - // static_assert(mma_tiles_per_warp_k == 4, 
"mma_tiles_per_warp_k must be 4"); +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE static_assert(mma_tiles_per_warp_k == 2, "mma_tiles_per_warp_k must be 2"); +#else + static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4"); +#endif static_assert(mma_tiles_per_warp_n == 8, "mma_tiles_per_warp_n must be 8"); - // uint32_t (®_) [4][8] = reinterpret_cast(reg); +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE uint32_t (®_) [2][8][2] = reinterpret_cast(reg); +#else + uint32_t (®_) [4][8] = reinterpret_cast(reg); +#endif unsigned int logical_offset = (threadIdx.x % 32) * smem_stride; unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b10000000) >> 4); swizzled_offset = swizzled_offset ^ ((swizzled_offset & 0b1100000) >> 2); @@ -614,21 +637,7 @@ __device__ __forceinline__ void ldmatrix_b( constexpr unsigned int smem_stride_ = smem_stride * sizeof(half); // convert stride to bytes // 0 - // asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[0][0]), "=r"(reg_[0][1]), "=r"(reg_[0][2]), "=r"(reg_[0][3]) - // : "r"(src_addr) - // ); - - - // asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[0][4]), "=r"(reg_[0][5]), "=r"(reg_[0][6]), "=r"(reg_[0][7]) - // : "r"(src_addr + 32 * smem_stride_) - // ); - +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" @@ -636,30 +645,32 @@ __device__ __forceinline__ void ldmatrix_b( : "r"(src_addr) ); - asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" : "=r"(reg_[0][4][0]), "=r"(reg_[0][5][0]), "=r"(reg_[0][6][0]), "=r"(reg_[0][7][0]) : "r"(src_addr + 32 * smem_stride_) ); +#else + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[0][0]), "=r"(reg_[0][1]), "=r"(reg_[0][2]), "=r"(reg_[0][3]) + : "r"(src_addr) + ); + + + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[0][4]), "=r"(reg_[0][5]), "=r"(reg_[0][6]), "=r"(reg_[0][7]) + : "r"(src_addr + 32 * smem_stride_) + ); +#endif src_addr ^= 0b10000; - // asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[1][0]), "=r"(reg_[1][1]), "=r"(reg_[1][2]), "=r"(reg_[1][3]) - // : "r"(src_addr) - // ); - - // asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[1][4]), "=r"(reg_[1][5]), "=r"(reg_[1][6]), "=r"(reg_[1][7]) - // : "r"(src_addr + 32 * smem_stride_) - // ); - +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" @@ -673,23 +684,25 @@ __device__ __forceinline__ void ldmatrix_b( : "=r"(reg_[0][4][1]), "=r"(reg_[0][5][1]), "=r"(reg_[0][6][1]), "=r"(reg_[0][7][1]) : "r"(src_addr + 32 * smem_stride_) ); +#else + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[1][0]), "=r"(reg_[1][1]), "=r"(reg_[1][2]), "=r"(reg_[1][3]) + : "r"(src_addr) + ); + + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[1][4]), "=r"(reg_[1][5]), "=r"(reg_[1][6]), "=r"(reg_[1][7]) + : "r"(src_addr + 32 * smem_stride_) + ); +#endif src_addr ^= 0b110000; - // asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[2][0]), "=r"(reg_[2][1]), 
"=r"(reg_[2][2]), "=r"(reg_[2][3]) - // : "r"(src_addr) - // ); - - // asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[2][4]), "=r"(reg_[2][5]), "=r"(reg_[2][6]), "=r"(reg_[2][7]) - // : "r"(src_addr + 32 * smem_stride_) - // ); - +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" @@ -703,23 +716,25 @@ __device__ __forceinline__ void ldmatrix_b( : "=r"(reg_[1][4][0]), "=r"(reg_[1][5][0]), "=r"(reg_[1][6][0]), "=r"(reg_[1][7][0]) : "r"(src_addr + 32 * smem_stride_) ); +#else + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[2][0]), "=r"(reg_[2][1]), "=r"(reg_[2][2]), "=r"(reg_[2][3]) + : "r"(src_addr) + ); + + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[2][4]), "=r"(reg_[2][5]), "=r"(reg_[2][6]), "=r"(reg_[2][7]) + : "r"(src_addr + 32 * smem_stride_) + ); +#endif src_addr ^= 0b10000; - // asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[3][0]), "=r"(reg_[3][1]), "=r"(reg_[3][2]), "=r"(reg_[3][3]) - // : "r"(src_addr) - // ); - - // asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[3][4]), "=r"(reg_[3][5]), "=r"(reg_[3][6]), "=r"(reg_[3][7]) - // : "r"(src_addr + 32 * smem_stride_) - // ); - +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" @@ -733,6 +748,21 @@ __device__ __forceinline__ void ldmatrix_b( : "=r"(reg_[1][4][1]), "=r"(reg_[1][5][1]), "=r"(reg_[1][6][1]), "=r"(reg_[1][7][1]) : "r"(src_addr + 32 * smem_stride_) ); +#else + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[3][0]), "=r"(reg_[3][1]), "=r"(reg_[3][2]), "=r"(reg_[3][3]) + : "r"(src_addr) + ); + + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[3][4]), "=r"(reg_[3][5]), "=r"(reg_[3][6]), "=r"(reg_[3][7]) + : "r"(src_addr + 32 * smem_stride_) + ); +#endif #else GGML_UNUSED(src); GGML_UNUSED(reg); @@ -760,7 +790,11 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, const unsigned int NKPQ = param.n * KPQ; // loop bounds, constexpr where possible allows for loop unrolling +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE constexpr unsigned int mma_tiles_per_warp_k = 2; +#else + constexpr unsigned int mma_tiles_per_warp_k = 4; +#endif constexpr unsigned int mma_tiles_per_warp_m = WM / MMA_M; constexpr unsigned int mma_tiles_per_warp_n = WN / MMA_N; const unsigned int z = blockIdx.z; @@ -787,14 +821,23 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, // declare register storage // ptx instructions expect uint32_t registers, where each uint32_t is 2 halfs packed together uint32_t acc_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][2]; +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE uint32_t A_register[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4]; uint32_t B_register[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2]; +#else + uint32_t A_register[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2]; + uint32_t B_register[mma_tiles_per_warp_k][mma_tiles_per_warp_n]; +#endif // convenience cast to half for register storage half (&acc_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_n][4] = reinterpret_cast(acc_register); 
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE half (&A_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][8] = reinterpret_cast(A_register); half (&B_register_) [mma_tiles_per_warp_k][mma_tiles_per_warp_n][4] = reinterpret_cast(B_register); - +#else + half (&A_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] = reinterpret_cast(A_register); + half (&B_register_) [mma_tiles_per_warp_k][mma_tiles_per_warp_n][2] = reinterpret_cast(B_register); +#endif // accumulators start at 0 for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++){ for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++){ @@ -844,17 +887,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++){ #pragma unroll for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++){ - // asm volatile ( - // "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " - // "{%0, %1}, " - // "{%2, %3}, " - // "{%4}, " - // "{%5, %6};" - // : "=r"(acc_register[mma_m][mma_n][0]), "=r"(acc_register[mma_m][mma_n][1]) - // : "r"(A_register[mma_m][mma_k][0]), "r"(A_register[mma_m][mma_k][1]), - // "r"(B_register[mma_k][mma_n]) - // "r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1]) - // ); +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm volatile ( "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " "{%0, %1}, " @@ -866,6 +899,19 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, "r"(B_register[mma_k][mma_n][0]), "r"(B_register[mma_k][mma_n][1]) "r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1]) ); +#else + asm volatile ( + "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " + "{%0, %1}, " + "{%2, %3}, " + "{%4}, " + "{%5, %6};" + : "=r"(acc_register[mma_m][mma_n][0]), "=r"(acc_register[mma_m][mma_n][1]) + : "r"(A_register[mma_m][mma_k][0]), "r"(A_register[mma_m][mma_k][1]), + "r"(B_register[mma_k][mma_n]) + "r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1]) + ); +#endif } } } diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index 90ef1e5237..8ee0747989 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -714,15 +714,15 @@ int main(void) // for(int i = 0; i < ggml_nelements(wino_res); i++) { // for(int i = 0; i < 26*38; i++) { - for(int i = 0; i < conv2d_data.size(); i++) { - float diff = fabs(im2col_data[i] - conv2d_data[i]); - // if(diff > 0.5) { - printf("(%7.3f, %7.3f, %.2f, %d) \n", - im2col_data[i], conv2d_data[i], - diff, i); - // break; - // } - } + // for(int i = 0; i < conv2d_data.size(); i++) { + // float diff = fabs(im2col_data[i] - conv2d_data[i]); + // // if(diff > 0.5) { + // printf("(%7.3f, %7.3f, %.2f, %d) \n", + // im2col_data[i], conv2d_data[i], + // diff, i); + // // break; + // // } + // } ggml_free(model.ctx); ggml_backend_buffer_free(model.buffer); From 093951184636bb950354e666739f98c4085066bf Mon Sep 17 00:00:00 2001 From: bssrdf Date: Thu, 13 Nov 2025 15:45:43 -0500 Subject: [PATCH 086/122] change mac loop to match cutlass --- ggml/src/ggml-cuda/conv2d-implicit.cu | 44 ++++++-- ggml/src/ggml-cuda/conv2d-implicit.cuh | 145 +++++++++++++++++++++---- tests/test-conv2d.cpp | 6 +- 3 files changed, 163 insertions(+), 32 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 529a0b50fd..bd67ac2b86 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -781,9 +781,9 @@ static __global__ void 
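// Minimal sketch of the arch dispatch this patch introduces: both paths cover the same
// 32-wide K slice per warp per step, Ampere and newer via m16n8k16 (2 k-tiles per warp),
// pre-Ampere via m16n8k8 (4 k-tiles). The *_SKETCH names are illustrative, not the
// kernel's identifiers; 800 stands in for GGML_CUDA_CC_AMPERE.
#if __CUDA_ARCH__ >= 800
constexpr unsigned int MMA_K_SKETCH              = 16;       // m16n8k16 HMMA
constexpr unsigned int MMA_TILES_PER_WARP_K_SKTC = 32 / 16;  // == 2
#else
constexpr unsigned int MMA_K_SKETCH              = 8;        // m16n8k8 HMMA (Turing)
constexpr unsigned int MMA_TILES_PER_WARP_K_SKTC = 32 / 8;   // == 4
#endif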
conv2d_implicit_kernel(const half * __restrict__ input, constexpr unsigned int MMA_M = 16; constexpr unsigned int MMA_N = 8; - const unsigned int K = param.c * param.r * param.s; + const unsigned int K = param.c; const uint inChannelOffset = param.c * param.w; - const uint weightKOffset = K; + const uint weightKOffset = param.c * param.r * param.s; const unsigned int PQ = param.Ow * param.Oh; const unsigned int KPQ = param.k * PQ; @@ -799,18 +799,25 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, constexpr unsigned int mma_tiles_per_warp_n = WN / MMA_N; const unsigned int z = blockIdx.z; - const unsigned int ks = (ksplit > 0) ? (weightKOffset + ksplit - 1) / ksplit : weightKOffset; + const unsigned int ks = (ksplit > 0) ? (K + ksplit - 1) / ksplit : K; const unsigned int start_k = (ksplit > 0) ? z * ks : 0; - const unsigned int end_k = min(start_k + ks, weightKOffset); + const unsigned int end_k = min(start_k + ks, K); const unsigned int num_block_tiles_k = (ks + (BK-1)) / BK; + constexpr unsigned int TILE_COLS_VECTORIZED = BK / 8; + constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; + constexpr unsigned int A_K_STRID = BM / ROW_STEP; + constexpr unsigned int B_K_STRID = BN / ROW_STEP; + unsigned int masks_a[A_K_STRID][2]; + unsigned int element_offset_a[A_K_STRID]; // calculate block/warp indices const unsigned int block_m = blockIdx.y; const unsigned int block_n = blockIdx.x; const unsigned int warp_m = threadIdx.y; const unsigned int warp_n = threadIdx.x / 32; + const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; // double buffering extern __shared__ half shmem[]; @@ -858,12 +865,21 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, float4 A_gmem_cache_reg[4]; float4 B_gmem_cache_reg[4]; + + prepareIteratorA(thread_idx, masks_a, element_offset_a, param); + + // prefetch the first block tile of A,B into shared memory const half* A_block_gmem = input; const half* B_block_gmem = kernel + block_n * BN * weightKOffset; - tileMemcpySwizzleA(A_block_gmem, A_block_smem, start_k, end_k, inChannelOffset, param); - tileMemcpySwizzleB(B_block_gmem, B_block_smem, start_k, end_k, weightKOffset, param); + int s = 0; + int r = 0; + while (r < param.r) { + // for (int r = 0; r < param.r; ++r) { + + tileMemcpySwizzleA(A_block_gmem, A_block_smem, r, s, masks_a, element_offset_a, thread_idx, start_k, end_k, inChannelOffset, param); + tileMemcpySwizzleB(B_block_gmem, B_block_smem, r, s, start_k, end_k, weightKOffset, param); int offset_direction = 1; @@ -871,8 +887,8 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, __syncthreads(); if (block_k != num_block_tiles_k){ - tileMemcpyLoadA(A_block_gmem, A_gmem_cache_reg, block_k * BK, start_k, end_k, inChannelOffset, param); - tileMemcpyLoadB(B_block_gmem, B_gmem_cache_reg, block_k * BK, start_k, end_k, weightKOffset, param); + tileMemcpyLoadA(A_block_gmem, A_gmem_cache_reg, r, s, block_k * BK, start_k, end_k, inChannelOffset, param); + tileMemcpyLoadB(B_block_gmem, B_gmem_cache_reg, r, s, block_k * BK, start_k, end_k, weightKOffset, param); } half* A_warp_tile = A_block_smem + A_warp_tile_offset; half* B_warp_tile = B_block_smem + B_warp_tile_offset; @@ -926,7 +942,14 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, tileMemcpySwizzleStore(A_gmem_cache_reg, A_block_smem); tileMemcpySwizzleStore(B_gmem_cache_reg, B_block_smem); } - } + } // iter block_k + + s++; + if (s == param.s) { + s = 0; + 
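// Minimal reference sketch of the reduction order this patch moves to (matching the
// CUTLASS forward-prop iterator): the implicit-GEMM K dimension becomes just the input
// channels C, and the filter taps (r, s) form an outer loop, so every (r, s) pair reuses
// the same BK-wide channel tiling. x stands for the im2col row of one output pixel and
// w for the filter row of one output channel, both laid out [R, S, C]; the function is
// purely illustrative, not part of the kernel.
static float conv_accum_ref(const float * x, const float * w,
                            int C, int R, int S, int BK) {
    float acc = 0.0f;
    for (int r = 0; r < R; ++r) {
        for (int s = 0; s < S; ++s) {
            for (int c0 = 0; c0 < C; c0 += BK) {              // channel tiles of width BK
                for (int c = c0; c < c0 + BK && c < C; ++c) {
                    acc += x[(r * S + s) * C + c] * w[(r * S + s) * C + c];
                }
            }
        }
    }
    return acc;
}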
r++; + } + } // iter r // reuse smem half *smemoutput = shmem; @@ -1166,7 +1189,8 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa ks = 16; for (j = 2; j <= ks; j++){ const int remainder = (BlocksM * BlocksN * j) % nsm; - if ((P.c * P.r * P.s) % (8*j) == 0){ + // if ((P.c * P.r * P.s) % (8*j) == 0){ + if ((P.c) % (8*j) == 0){ if (remainder == 0) { candidate = j; max_remaining_waves = 0; diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index 981a183fd9..0f25b38dd6 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -26,12 +26,89 @@ typedef struct{ } param_t; +/// Clears the predicates + +template +__host__ __device__ void clear_mask(unsigned int masks_[][2], bool clear = true) { + +#pragma unroll + for (int s = 0; s < K_STRID; ++s) { + masks_[s][0] = clear ? 0 : masks_[s][0]; + masks_[s][1] = clear ? 0 : masks_[s][1]; + } +} + +template +__device__ void prepareIteratorA(const int thread_idx, + unsigned int masks[][2], + unsigned int element_offset[], + const param_t param){ + int offset_n[A_K_STRID]; + int offset_p[A_K_STRID]; + int offset_q[A_K_STRID]; + + constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; + unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; + const unsigned int chw = param.c * param.h * param.w; + +#pragma unroll + for (int s = 0; s < A_K_STRID; ++s) { + + // pointer_[s] = reinterpret_cast(ptr); + + // int offset_npq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + const unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row; + offset_n[s] = fastdiv(gemm_i, param.OHOW_fastdiv); + unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv); + offset_p[s] = fastdiv(npq_res, param.OW_fastdiv); //* param.u - param.p; + offset_q[s] = fastmodulo(npq_res, param.OW_fastdiv); // * param.v - param.q; + const int h = offset_p[s] * param.u - param.p; + const int w = offset_q[s] * param.v - param.q; + + // if(threadIdx.x < 32 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0) + // printf("%d, %d : %d, %d, %d, %d offset (%d, %d, %d), kele %llu Kcont %d\n ", thread_idx, s, + // // printf("[%s - %d] %d, %d : %d, %d, %d, %d\n ", __FUNCTION__, __LINE__, thread_idx, s, + // threadblock_offset.row(), thread_coord.strided(), ThreadMap::Delta::kStrided, + // offset_npq, offset_n[s], offset_p[s], offset_q[s], AccessType::kElements, + // ThreadMap::Iterations::kContiguous); + + element_offset[s] = offset_n[s] * chw + h * param.c * param.w + w * param.c; + thread_row += ROW_STEP; + } + + clear_mask(masks); + + for (int r = 0; r < param.r; ++r) { +#pragma unroll + for (int s_idx = 0; s_idx < A_K_STRID; ++s_idx) { + const int h = offset_p[s_idx] * param.u - param.p + r * param.d_h; + + bool pred = (offset_n[s_idx] < param.n && h >= 0 && h < param.h); + masks[s_idx][0] |= (pred << r); + } + } + + for (int s = 0; s < param.s; ++s) { +#pragma unroll + for (int s_idx = 0; s_idx < A_K_STRID; ++s_idx) { + const int w = offset_q[s_idx] * param.v - param.q + s * param.d_w; + bool pred = (w >= 0 && w < param.w); + masks[s_idx][1] |= (pred << s); + } + } +} + // same as above, but writes are swizzled to avoid bank conflicts when shared memory is read later in the kernel template __device__ __forceinline__ void tileMemcpySwizzleB( const half* src, half* dst, + const unsigned int curR, + const unsigned int curS, const unsigned int start_k, const unsigned int end_k, const unsigned int src_stride, @@ -60,10 +137,12 
@@ __device__ __forceinline__ void tileMemcpySwizzleB( unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; - const unsigned int ki = start_k+thread_col*8; - const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset - const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // + // const unsigned int ki = (curR*param.s+curS)*param.c + start_k+thread_col*8; + // const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset + // const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + // const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // + const unsigned int curC = start_k+thread_col*8; + const unsigned int ki = (curR*param.s+curS)*param.c + curC; #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++){ @@ -72,7 +151,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB( unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); - if (thread_row + blockIdx.x * TILE_ROWS < param.k && curR < param.r && curS < param.s && curC < param.c && ki < end_k){ + if (thread_row + blockIdx.x * TILE_ROWS < param.k && curC < end_k){ dst_float4[dst_index] = reinterpret_cast(&src[src_index])[0]; }else{ // read 4 halves dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f); @@ -95,6 +174,11 @@ unsigned int NUM_THREADS> __device__ __forceinline__ void tileMemcpySwizzleA( const half* src, half* dst, + const unsigned int curR, + const unsigned int curS, + unsigned int masks[][2], + unsigned int element_offset[], + const unsigned int thread_idx, const unsigned int start_k, const unsigned int end_k, const unsigned int inChannelOffset, @@ -115,7 +199,7 @@ __device__ __forceinline__ void tileMemcpySwizzleA( constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); // flatten out 2d grid of threads into in order of increasing threadIdx.x - const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; + // const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; // assign each thread a row/column in the tile, calculate how many iterations we need // to cover the whole tile @@ -126,11 +210,27 @@ __device__ __forceinline__ void tileMemcpySwizzleA( const unsigned int ki = start_k+thread_col*8; const unsigned int chw = param.c * param.h * param.w; - const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset - const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - - + // const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset + // const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + // const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const unsigned int curC = ki; + // #pragma unroll + // for (unsigned int i = 0; i < NUM_ITERS; i++){ + // bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS)); + // // apply swizzle to the 
dst index + // unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; + // dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); + // dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); + // if (valid && ki < end_k){ + // if(element_offset[i]+curC >= 327680 || element_offset[i]+curC < 0) + // printf("%d, %d, %d, %d, %d, %d, %d, %d, %d \n", threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y, + // i, element_offset[i], curR, curS, curC); + // dst_float4[dst_index] = reinterpret_cast(&src[element_offset[i]+curC])[0]; + // } else{ + // dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f); + // } + // thread_row += ROW_STEP; + // } #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++){ unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row; @@ -170,7 +270,8 @@ unsigned int ELEMENTS_PER_THREAD> __device__ __forceinline__ void tileMemcpyLoadA( const half* src, float4 (&dst_reg)[ELEMENTS_PER_THREAD], - // const unsigned int src_stride, + const unsigned int curR, + const unsigned int curS, const unsigned int block_k, const unsigned int start_k, const unsigned int end_k, @@ -199,9 +300,10 @@ __device__ __forceinline__ void tileMemcpyLoadA( const unsigned int ki = start_k+block_k+thread_col*8; const unsigned int chw = param.c * param.h * param.w; - const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset - const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + // const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset + // const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + // const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const unsigned int curC = ki; #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++){ @@ -240,6 +342,8 @@ unsigned int ELEMENTS_PER_THREAD> __device__ __forceinline__ void tileMemcpyLoadB( const half* src, float4 (&dst_reg)[ELEMENTS_PER_THREAD], + const unsigned int curR, + const unsigned int curS, const unsigned int block_k, const unsigned int start_k, const unsigned int end_k, @@ -265,15 +369,16 @@ __device__ __forceinline__ void tileMemcpyLoadB( // compile time check that we provided the right amount of registers for storage static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); - const unsigned int ki = start_k+block_k+thread_col*8; - const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset - const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // + // const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset + // const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + // const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // + const unsigned int curC = start_k+block_k+thread_col*8; + const unsigned int ki = (curR*param.s+curS)*param.c + curC; #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++){ const unsigned int src_index = thread_row * src_stride + ki; - if (thread_row + blockIdx.x * TILE_ROWS < param.k && curR < param.r && curS < param.s && curC < param.c && ki < end_k){ + if (thread_row + blockIdx.x * TILE_ROWS < param.k && 
curC < end_k){ dst_reg[i] = reinterpret_cast(&src[src_index])[0]; }else{ // read 4 halves dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f); diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index 8ee0747989..daac5c9605 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -301,7 +301,9 @@ static std::vector> configs = { // std::make_tuple(960,320,104,152,3,3), // std::make_tuple(1280,1280,26,38,3,3), // std::make_tuple(1920,640,32,32,3,3) - std::make_tuple(1280,1280,16,16,3,3), + // std::make_tuple(1280,1280,16,16,3,3), + // std::make_tuple(32,8,24,24,3,3), + std::make_tuple(640,640,64,64,3,3), // std::make_tuple(320,640,32,32,3,3), // std::make_tuple(4,320,96,128,3,3), // std::make_tuple(320,4,96,128,3,3), @@ -671,7 +673,7 @@ int main(void) // fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f); - int iterations = 0; + int iterations = 20; double run_time0; std::vector im2col_data = compute_graph(model, allocr, build_graph_0, iterations, &run_time0); From 8bfb7ed2f246dbd63c8fd15d2afe8616172e2b1f Mon Sep 17 00:00:00 2001 From: bssrdf Date: Thu, 13 Nov 2025 16:32:27 -0500 Subject: [PATCH 087/122] restore smem pointer at teh end of evry rs loop --- ggml/src/ggml-cuda/conv2d-implicit.cu | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index bd67ac2b86..d2b5ee6d33 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -949,6 +949,8 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, s = 0; r++; } + A_block_smem = shmem; + B_block_smem = &shmem[BM * BK]; } // iter r // reuse smem From 63c53fe1f1568f11ac69e07ffdce2f84968675ab Mon Sep 17 00:00:00 2001 From: bssrdf Date: Thu, 13 Nov 2025 18:44:32 -0500 Subject: [PATCH 088/122] WIP: move rs loop into block-k-loop following cutlass --- ggml/src/ggml-cuda/conv2d-implicit.cu | 63 ++++++++++++++++++--------- tests/test-conv2d.cpp | 24 +++++----- 2 files changed, 55 insertions(+), 32 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index d2b5ee6d33..d04c379bb0 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -803,6 +803,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, const unsigned int start_k = (ksplit > 0) ? 
z * ks : 0; const unsigned int end_k = min(start_k + ks, K); const unsigned int num_block_tiles_k = (ks + (BK-1)) / BK; + const unsigned int num_block_tiles_krs = num_block_tiles_k * param.r * param.s; constexpr unsigned int TILE_COLS_VECTORIZED = BK / 8; constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; @@ -867,26 +868,49 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, prepareIteratorA(thread_idx, masks_a, element_offset_a, param); - // prefetch the first block tile of A,B into shared memory const half* A_block_gmem = input; const half* B_block_gmem = kernel + block_n * BN * weightKOffset; - int s = 0; - int r = 0; - while (r < param.r) { - // for (int r = 0; r < param.r; ++r) { - tileMemcpySwizzleA(A_block_gmem, A_block_smem, r, s, masks_a, element_offset_a, thread_idx, start_k, end_k, inChannelOffset, param); - tileMemcpySwizzleB(B_block_gmem, B_block_smem, r, s, start_k, end_k, weightKOffset, param); + tileMemcpySwizzleA(A_block_gmem, A_block_smem, 0, 0, masks_a, element_offset_a, thread_idx, start_k, end_k, inChannelOffset, param); + tileMemcpySwizzleB(B_block_gmem, B_block_smem, 0, 0, start_k, end_k, weightKOffset, param); int offset_direction = 1; - - for (unsigned int block_k = 1; block_k <= num_block_tiles_k; block_k++){ + unsigned int block_k = 0; + unsigned int block_krs = 1; + // for (unsigned int block_k = 1; block_k <= num_block_tiles_k; block_k++){ + int s = 0; + int r = 0; + while (block_k < num_block_tiles_k){ __syncthreads(); - if (block_k != num_block_tiles_k){ + // moves to the next tile + int next_idx = 0; + ++s; + if (s == param.s) { + s = 0; + ++r; + if (r < param.r) { + next_idx = 1; + } else { + r = 0; + next_idx = 2; + } + } + if (next_idx == 2) { + ++block_k; + } + // if(block_k == num_block_tiles_k) + // break; + + // if(thread_idx == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0){ + // printf(" s = %d, r = %d, block_k = %d, next_idx = %d , %d %d \n", s, r, block_k, next_idx, block_krs, num_block_tiles_k); + // } + + // if (block_k != num_block_tiles_k){ + if (block_krs != num_block_tiles_krs){ tileMemcpyLoadA(A_block_gmem, A_gmem_cache_reg, r, s, block_k * BK, start_k, end_k, inChannelOffset, param); tileMemcpyLoadB(B_block_gmem, B_gmem_cache_reg, r, s, block_k * BK, start_k, end_k, weightKOffset, param); } @@ -932,7 +956,8 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, } } - if (block_k != num_block_tiles_k) + // if (block_k != num_block_tiles_k) + if (block_krs != num_block_tiles_krs) { // switch smem buffers each iteration A_block_smem = A_block_smem + BUFFER_SIZE * offset_direction; @@ -942,16 +967,14 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, tileMemcpySwizzleStore(A_gmem_cache_reg, A_block_smem); tileMemcpySwizzleStore(B_gmem_cache_reg, B_block_smem); } - } // iter block_k - s++; - if (s == param.s) { - s = 0; - r++; - } - A_block_smem = shmem; - B_block_smem = &shmem[BM * BK]; - } // iter r + block_krs++; + + } + // A_block_smem = shmem; + // B_block_smem = &shmem[BM * BK]; + + // } // iter block_k // reuse smem half *smemoutput = shmem; diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index daac5c9605..e3968f28b8 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -301,9 +301,9 @@ static std::vector> configs = { // std::make_tuple(960,320,104,152,3,3), // std::make_tuple(1280,1280,26,38,3,3), // std::make_tuple(1920,640,32,32,3,3) - // std::make_tuple(1280,1280,16,16,3,3), + 
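// Minimal sketch of the CUTLASS-style tile advance used above: s varies fastest, then r,
// and only when both wrap does the channel tile index move forward. The returned state
// (0 = next s, 1 = next r, 2 = next channel tile) selects the matching precomputed NHWC
// pointer delta in inc_next[]. advance_rs is an illustrative name, not the kernel's.
static __host__ __device__ int advance_rs(int & r, int & s, int R, int S) {
    ++s;
    if (s < S) {
        return 0;          // same r, move one tap to the right
    }
    s = 0;
    ++r;
    if (r < R) {
        return 1;          // move one tap down, rewind s
    }
    r = 0;
    return 2;              // filter fully visited: advance the channel tile
}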
std::make_tuple(1280,1280,16,16,3,3), // std::make_tuple(32,8,24,24,3,3), - std::make_tuple(640,640,64,64,3,3), + // std::make_tuple(640,640,64,64,3,3), // std::make_tuple(320,640,32,32,3,3), // std::make_tuple(4,320,96,128,3,3), // std::make_tuple(320,4,96,128,3,3), @@ -673,7 +673,7 @@ int main(void) // fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f); - int iterations = 20; + int iterations = 0; double run_time0; std::vector im2col_data = compute_graph(model, allocr, build_graph_0, iterations, &run_time0); @@ -716,15 +716,15 @@ int main(void) // for(int i = 0; i < ggml_nelements(wino_res); i++) { // for(int i = 0; i < 26*38; i++) { - // for(int i = 0; i < conv2d_data.size(); i++) { - // float diff = fabs(im2col_data[i] - conv2d_data[i]); - // // if(diff > 0.5) { - // printf("(%7.3f, %7.3f, %.2f, %d) \n", - // im2col_data[i], conv2d_data[i], - // diff, i); - // // break; - // // } - // } + for(int i = 0; i < conv2d_data.size(); i++) { + float diff = fabs(im2col_data[i] - conv2d_data[i]); + // if(diff > 0.5) { + printf("(%7.3f, %7.3f, %.2f, %d) \n", + im2col_data[i], conv2d_data[i], + diff, i); + // break; + // } + } ggml_free(model.ctx); ggml_backend_buffer_free(model.buffer); From 7d99222a61004c6c05227ccfc0a469b6f7c5541d Mon Sep 17 00:00:00 2001 From: bssrdf Date: Thu, 13 Nov 2025 22:08:41 -0500 Subject: [PATCH 089/122] WIP: debugging --- ggml/src/ggml-cuda/conv2d-implicit.cu | 37 +++++++- ggml/src/ggml-cuda/conv2d-implicit.cuh | 125 ++++++++++++++----------- 2 files changed, 104 insertions(+), 58 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index d04c379bb0..5ec616a978 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -811,7 +811,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, constexpr unsigned int B_K_STRID = BN / ROW_STEP; unsigned int masks_a[A_K_STRID][2]; - unsigned int element_offset_a[A_K_STRID]; + int64_t element_offset_a[A_K_STRID]; // calculate block/warp indices const unsigned int block_m = blockIdx.y; @@ -867,6 +867,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, float4 B_gmem_cache_reg[4]; + prepareIteratorA(thread_idx, masks_a, element_offset_a, param); // prefetch the first block tile of A,B into shared memory @@ -874,7 +875,8 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, const half* A_block_gmem = input; const half* B_block_gmem = kernel + block_n * BN * weightKOffset; - tileMemcpySwizzleA(A_block_gmem, A_block_smem, 0, 0, masks_a, element_offset_a, thread_idx, start_k, end_k, inChannelOffset, param); + tileMemcpySwizzleA(A_block_gmem, A_block_smem, 0, 0, masks_a, element_offset_a, + thread_idx, start_k, end_k, inChannelOffset, param); tileMemcpySwizzleB(B_block_gmem, B_block_smem, 0, 0, start_k, end_k, weightKOffset, param); int offset_direction = 1; @@ -899,6 +901,9 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, next_idx = 2; } } + + add_byte_offset(element_offset_a, param.inc_next[next_idx]); + if (next_idx == 2) { ++block_k; } @@ -911,7 +916,9 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, // if (block_k != num_block_tiles_k){ if (block_krs != num_block_tiles_krs){ - tileMemcpyLoadA(A_block_gmem, A_gmem_cache_reg, r, s, block_k * BK, start_k, end_k, inChannelOffset, param); + tileMemcpyLoadA(A_block_gmem, A_gmem_cache_reg, r, s, + masks_a, 
element_offset_a, thread_idx, block_k * BK, + start_k, end_k, inChannelOffset, param); tileMemcpyLoadB(B_block_gmem, B_gmem_cache_reg, r, s, block_k * BK, start_k, end_k, weightKOffset, param); } half* A_warp_tile = A_block_smem + A_warp_tile_offset; @@ -1096,7 +1103,7 @@ template<<>>(X_H, K_H, Y_H.get(), P); + int64_t inc[3]; + // next S + inc[0] = int64_t(P.c) * P.d_w; + // next R + inc[1] = int64_t(P.w * P.c) * P.d_h - (P.s - 1) * P.c * P.d_w; + // next C + inc[2] = BK - int64_t(P.r - 1) * P.w * P.c * P.d_h - int64_t(P.s - 1) * P.c * P.d_w ; + memcpy(P.inc_next, inc, sizeof(int64_t)*3); + const unsigned int nrows = P.n * P.k * P.Oh * P.Ow; const unsigned int blockx = (nrows + 511) / 512; const dim3 block_nums(blockx, 1, 1); @@ -1116,7 +1132,7 @@ static void launch_conv2d_implicit_split_kernel(ggml_backend_cuda_context & ctx, reduce_f32<<>>(Y_H.get(), Y_D, nrows, ksplit); } -static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const float * X_D, const half * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) { +static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const float * X_D, const half * K_D, float * Y_D, int cc, param_t P, cudaStream_t st) { // if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0 && (P.r > 1 || P.s > 1)) { if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0) { @@ -1279,6 +1295,15 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa } } + int64_t inc[3]; + // next S + inc[0] = int64_t(P.c) * P.d_w; + // next R + inc[1] = int64_t(P.w * P.c) * P.d_h - (P.s - 1) * P.c * P.d_w; + // next C + inc[2] = BK_dim - int64_t(P.r - 1) * P.w * P.c * P.d_h - int64_t(P.s - 1) * P.c * P.d_w ; + memcpy(P.inc_next, inc, sizeof(int64_t)*3); + cudaFuncSetAttribute(conv2d_implicit_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75 dim3 gridDim(BlocksN, BlocksM); @@ -1340,6 +1365,8 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * const uint OC = kernel->ne[3]; // ouptut_chanles const uint B = input->ne[3]; // n_batches + + param_t params = { B, IC, IH, IW, OC, KH, KW, ST_Y, ST_X, PD_Y, PD_X, DL_Y, DL_X, OH, OW, init_fastdiv_values(KW*IC), init_fastdiv_values(OW), diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index 0f25b38dd6..22b597f7bb 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -23,6 +23,7 @@ typedef struct{ uint3 RS_fastdiv; uint3 S_fastdiv; uint3 OHOW_fastdiv; + int64_t inc_next[3]; } param_t; @@ -38,13 +39,21 @@ __host__ __device__ void clear_mask(unsigned int masks_[][2], bool clear = true) } } +template +__host__ __device__ void add_byte_offset(int64_t element_offset[], const int64_t offset){ +#pragma unroll + for (int s = 0; s < K_STRID; ++s) { + element_offset[s] += offset; + } +} + template __device__ void prepareIteratorA(const int thread_idx, unsigned int masks[][2], - unsigned int element_offset[], + int64_t element_offset[], const param_t param){ int offset_n[A_K_STRID]; int offset_p[A_K_STRID]; @@ -176,8 +185,8 @@ __device__ __forceinline__ void tileMemcpySwizzleA( half* dst, const unsigned int curR, const unsigned int curS, - unsigned int masks[][2], - unsigned int element_offset[], + const unsigned int masks[][2], + const int64_t element_offset[], const unsigned int thread_idx, const unsigned int start_k, const unsigned int end_k, @@ -208,52 
+217,52 @@ __device__ __forceinline__ void tileMemcpySwizzleA( unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; - const unsigned int ki = start_k+thread_col*8; + // const unsigned int ki = start_k+thread_col*8; const unsigned int chw = param.c * param.h * param.w; // const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset // const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset // const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - const unsigned int curC = ki; - // #pragma unroll - // for (unsigned int i = 0; i < NUM_ITERS; i++){ - // bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS)); - // // apply swizzle to the dst index - // unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; - // dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); - // dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); - // if (valid && ki < end_k){ - // if(element_offset[i]+curC >= 327680 || element_offset[i]+curC < 0) - // printf("%d, %d, %d, %d, %d, %d, %d, %d, %d \n", threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y, - // i, element_offset[i], curR, curS, curC); - // dst_float4[dst_index] = reinterpret_cast(&src[element_offset[i]+curC])[0]; - // } else{ - // dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f); - // } - // thread_row += ROW_STEP; - // } + const unsigned int curC = start_k+thread_col*8; #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++){ - unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row; - unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv); - unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv); - int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p; - int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q; - // unsigned int inOffset = n * param.c * param.h * param.w; - int curH = posh_ori + curR * param.d_h; // input h - int curW = posw_ori + curS * param.d_w; // input w + bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS)); // apply swizzle to the dst index unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); - if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && - curR < param.r && curS < param.s && curC < param.c && n < param.n && ki < end_k){ - const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; - dst_float4[dst_index] = reinterpret_cast(&src[n * chw + inOffsetTmp])[0]; + if (valid && curC < end_k){ + if(element_offset[i] >= 327680 || element_offset[i] < 0) + printf("%d, %d, %d, %d, %d, %d, %d, %d, %d \n", threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y, + i, element_offset[i], curR, curS, curC); + dst_float4[dst_index] = reinterpret_cast(&src[element_offset[i]])[0]; } else{ dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f); } thread_row += ROW_STEP; } + // #pragma unroll + // for (unsigned int i = 0; i < NUM_ITERS; i++){ + // unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row; + // unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv); + // unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv); + // int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p; + // int posw_ori 
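// Minimal sketch of the precomputed predicate masks consulted above (the CUTLASS-style
// activation-iterator trick): for each strided row the r- and s-bounds checks are folded
// into two bitmasks once, so the copy loop only needs two bit tests instead of
// recomputing h/w coordinates per element. build_row_masks is an illustrative standalone
// helper, assuming h_base/w_base are the top-left input coordinates of the output pixel.
__device__ __forceinline__ void build_row_masks(unsigned int (&mask)[2],
                                                int n, int h_base, int w_base,
                                                int N, int H, int W,
                                                int R, int S, int d_h, int d_w) {
    mask[0] = 0u;
    mask[1] = 0u;
    for (int r = 0; r < R; ++r) {
        const int h = h_base + r * d_h;
        mask[0] |= (unsigned int)(n < N && h >= 0 && h < H) << r;
    }
    for (int s = 0; s < S; ++s) {
        const int w = w_base + s * d_w;
        mask[1] |= (unsigned int)(w >= 0 && w < W) << s;
    }
}
// The inner copy loop then reduces to:
//   bool valid = (mask[0] & (1u << r)) && (mask[1] & (1u << s));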
= fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q; + // // unsigned int inOffset = n * param.c * param.h * param.w; + // int curH = posh_ori + curR * param.d_h; // input h + // int curW = posw_ori + curS * param.d_w; // input w + // // apply swizzle to the dst index + // unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; + // dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); + // dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); + // if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && + // curR < param.r && curS < param.s && curC < param.c && n < param.n && ki < end_k){ + // const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; + // dst_float4[dst_index] = reinterpret_cast(&src[n * chw + inOffsetTmp])[0]; + // } else{ + // dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f); + // } + // thread_row += ROW_STEP; + // } #else GGML_UNUSED(src); GGML_UNUSED(dst); @@ -272,6 +281,9 @@ __device__ __forceinline__ void tileMemcpyLoadA( float4 (&dst_reg)[ELEMENTS_PER_THREAD], const unsigned int curR, const unsigned int curS, + const unsigned int masks[][2], + const int64_t element_offset[], + const unsigned int thread_idx, const unsigned int block_k, const unsigned int start_k, const unsigned int end_k, @@ -285,45 +297,52 @@ __device__ __forceinline__ void tileMemcpyLoadA( static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); // flatten out 2d grid of threads into in order of increasing threadIdx.x - const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; - // assign each thread a row/column in the tile, calculate how many iterations we need // to cover the whole tile constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; - unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; + // unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; // compile time check that we provided the right amount of registers for storage static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); - const unsigned int ki = start_k+block_k+thread_col*8; - const unsigned int chw = param.c * param.h * param.w; + // const unsigned int ki = start_k+block_k+thread_col*8; + // const unsigned int chw = param.c * param.h * param.w; // const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset // const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset // const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - const unsigned int curC = ki; + const unsigned int curC = start_k+block_k+thread_col*8;; #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++){ - unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row; - unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv); - unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv); - int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p; - int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q; - // unsigned int inOffset = n * param.c * param.h * param.w; - int curH = posh_ori + curR * param.d_h; // input h - int curW = posw_ori + curS * param.d_w; // input w - if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && - curR < param.r && curS < param.s && curC < param.c && n < param.n && ki < end_k){ - const unsigned int inOffsetTmp = curH * 
inChannelOffset + curW * param.c + curC; - dst_reg[i] = reinterpret_cast(&src[n * chw + inOffsetTmp])[0]; + bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS)); + if (valid && curC < end_k) { + dst_reg[i] = reinterpret_cast(&src[element_offset[i]])[0]; } else{ dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f); } - thread_row += ROW_STEP; } + // #pragma unroll + // for (unsigned int i = 0; i < NUM_ITERS; i++){ + // unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row; + // unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv); + // unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv); + // int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p; + // int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q; + // // unsigned int inOffset = n * param.c * param.h * param.w; + // int curH = posh_ori + curR * param.d_h; // input h + // int curW = posw_ori + curS * param.d_w; // input w + // if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && + // curR < param.r && curS < param.s && curC < param.c && n < param.n && ki < end_k){ + // const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; + // dst_reg[i] = reinterpret_cast(&src[n * chw + inOffsetTmp])[0]; + // } else{ + // dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f); + // } + // thread_row += ROW_STEP; + // } #else GGML_UNUSED(src); GGML_UNUSED(dst_reg); From b015e4b7dcb7fea5d0bed746fd7c0905869a64b1 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Fri, 14 Nov 2025 11:10:34 -0500 Subject: [PATCH 090/122] WIP: fixed bugs now results are correct --- ggml/src/ggml-cuda/conv2d-implicit.cu | 67 +++++++++++++++++--------- ggml/src/ggml-cuda/conv2d-implicit.cuh | 66 ++++++++++++++++++------- tests/test-conv2d.cpp | 18 +++---- 3 files changed, 102 insertions(+), 49 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 5ec616a978..1b49a9beb1 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -870,13 +870,28 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, prepareIteratorA(thread_idx, masks_a, element_offset_a, param); + // for(int kk =0; kk < A_K_STRID; kk++){ + // if(element_offset_a[kk] >= 327680) + // printf("%d, %d, %d, %d, %d, %lld \n", + // threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y, blockIdx.z, + // element_offset_a[kk]); + // } + + // if(threadIdx.x == 64 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0){ + // printf("A["); + // for(int kk =0; kk < A_K_STRID; kk++) + // printf("%f,", element_offset_a[kk]); + // printf("]\n"); + // } + + // prefetch the first block tile of A,B into shared memory const half* A_block_gmem = input; const half* B_block_gmem = kernel + block_n * BN * weightKOffset; - tileMemcpySwizzleA(A_block_gmem, A_block_smem, 0, 0, masks_a, element_offset_a, - thread_idx, start_k, end_k, inChannelOffset, param); + unsigned int curC = tileMemcpySwizzleA(A_block_gmem, A_block_smem, 0, 0, masks_a, element_offset_a, + thread_idx, start_k, end_k, inChannelOffset, param); tileMemcpySwizzleB(B_block_gmem, B_block_smem, 0, 0, start_k, end_k, weightKOffset, param); int offset_direction = 1; @@ -907,6 +922,18 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, if (next_idx == 2) { ++block_k; } + + // if(threadIdx.x == 64 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0){ + // printf("B %d,%d,%d [", s, r, block_k); + // for(int 
kk =0; kk < A_K_STRID; kk++){ + // if(element_offset_a[kk] >= 327680) + // printf("%d, %d, %d, %d, %d, %lld, %d, %d, %d %d, %lld\n", + // threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y, blockIdx.z, + // element_offset_a[kk], r, s, block_k, next_idx, param.inc_next[next_idx]); + // } + // threadIdx.x == 64 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0){ + // printf("%f,", element_offset_a[kk]); + // printf("]\n"); // if(block_k == num_block_tiles_k) // break; @@ -916,11 +943,12 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, // if (block_k != num_block_tiles_k){ if (block_krs != num_block_tiles_krs){ - tileMemcpyLoadA(A_block_gmem, A_gmem_cache_reg, r, s, + curC = tileMemcpyLoadA(A_block_gmem, A_gmem_cache_reg, r, s, masks_a, element_offset_a, thread_idx, block_k * BK, - start_k, end_k, inChannelOffset, param); + start_k, end_k, curC, inChannelOffset, param); tileMemcpyLoadB(B_block_gmem, B_gmem_cache_reg, r, s, block_k * BK, start_k, end_k, weightKOffset, param); } + half* A_warp_tile = A_block_smem + A_warp_tile_offset; half* B_warp_tile = B_block_smem + B_warp_tile_offset; @@ -983,6 +1011,10 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, // } // iter block_k + // if(threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0){ + // printf(" %u, %f\n", blockIdx.z, __half2float(acc_register_[0][0][0])); + // } + // reuse smem half *smemoutput = shmem; const uint lane_id = threadIdx.x % WARPSIZE; @@ -1116,15 +1148,6 @@ static void launch_conv2d_implicit_split_kernel(ggml_backend_cuda_context & ctx, conv2d_implicit_kernel<<>>(X_H, K_H, Y_H.get(), P); - int64_t inc[3]; - // next S - inc[0] = int64_t(P.c) * P.d_w; - // next R - inc[1] = int64_t(P.w * P.c) * P.d_h - (P.s - 1) * P.c * P.d_w; - // next C - inc[2] = BK - int64_t(P.r - 1) * P.w * P.c * P.d_h - int64_t(P.s - 1) * P.c * P.d_w ; - memcpy(P.inc_next, inc, sizeof(int64_t)*3); - const unsigned int nrows = P.n * P.k * P.Oh * P.Ow; const unsigned int blockx = (nrows + 511) / 512; const dim3 block_nums(blockx, 1, 1); @@ -1139,6 +1162,15 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa int id = ggml_cuda_get_device(); + int64_t inc[3]; + // next S + inc[0] = int64_t(P.c) * P.d_w; + // next R + inc[1] = int64_t(P.w * P.c) * P.d_h - (P.s - 1) * P.c * P.d_w; + // next C + inc[2] = - int64_t(P.r - 1) * P.w * P.c * P.d_h - int64_t(P.s - 1) * P.c * P.d_w ; + memcpy(P.inc_next, inc, sizeof(int64_t)*3); + int64_t ne = P.c * P.h * P.w * P.n; int64_t ne00 = P.c; int64_t ne01 = P.h * P.w; @@ -1295,15 +1327,6 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa } } - int64_t inc[3]; - // next S - inc[0] = int64_t(P.c) * P.d_w; - // next R - inc[1] = int64_t(P.w * P.c) * P.d_h - (P.s - 1) * P.c * P.d_w; - // next C - inc[2] = BK_dim - int64_t(P.r - 1) * P.w * P.c * P.d_h - int64_t(P.s - 1) * P.c * P.d_w ; - memcpy(P.inc_next, inc, sizeof(int64_t)*3); - cudaFuncSetAttribute(conv2d_implicit_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75 dim3 gridDim(BlocksN, BlocksM); diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index 22b597f7bb..982699f969 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -74,8 +74,8 @@ __device__ void prepareIteratorA(const int thread_idx, unsigned int npq_res = fastmodulo(gemm_i, 
param.OHOW_fastdiv); offset_p[s] = fastdiv(npq_res, param.OW_fastdiv); //* param.u - param.p; offset_q[s] = fastmodulo(npq_res, param.OW_fastdiv); // * param.v - param.q; - const int h = offset_p[s] * param.u - param.p; - const int w = offset_q[s] * param.v - param.q; + const int h = offset_p[s] * (int)param.u - (int) param.p; + const int w = offset_q[s] * (int)param.v - (int) param.q; // if(threadIdx.x < 32 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0) // printf("%d, %d : %d, %d, %d, %d offset (%d, %d, %d), kele %llu Kcont %d\n ", thread_idx, s, @@ -84,7 +84,12 @@ __device__ void prepareIteratorA(const int thread_idx, // offset_npq, offset_n[s], offset_p[s], offset_q[s], AccessType::kElements, // ThreadMap::Iterations::kContiguous); - element_offset[s] = offset_n[s] * chw + h * param.c * param.w + w * param.c; + element_offset[s] = offset_n[s] * (int64_t)chw + h * (int64_t)(param.c * param.w) + w * (int64_t)param.c; + + // if(element_offset[s] >= 327680) + // printf("(%d, %d, %d, %d, %d), %d, %lld, %d, %d, %d, %d, %d, %u, %u, %u \n", + // threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y, blockIdx.z, + // s, element_offset[s], offset_n[s], offset_p[s], offset_q[s], h, w, chw, param.c * param.w, param.c); thread_row += ROW_STEP; } @@ -180,12 +185,12 @@ __device__ __forceinline__ void tileMemcpySwizzleB( // this is a special case of the above for when TILE_COLS == 32 template -__device__ __forceinline__ void tileMemcpySwizzleA( +__device__ __forceinline__ unsigned int tileMemcpySwizzleA( const half* src, half* dst, const unsigned int curR, const unsigned int curS, - const unsigned int masks[][2], + unsigned int masks[][2], const int64_t element_offset[], const unsigned int thread_idx, const unsigned int start_k, @@ -218,23 +223,29 @@ __device__ __forceinline__ void tileMemcpySwizzleA( const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; // const unsigned int ki = start_k+thread_col*8; - const unsigned int chw = param.c * param.h * param.w; + // const unsigned int chw = param.c * param.h * param.w; // const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset // const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset // const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset const unsigned int curC = start_k+thread_col*8; + clear_mask(masks, curC >= end_k); + #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++){ - bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS)); + bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS)); // apply swizzle to the dst index unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); - if (valid && curC < end_k){ - if(element_offset[i] >= 327680 || element_offset[i] < 0) - printf("%d, %d, %d, %d, %d, %d, %d, %d, %d \n", threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y, - i, element_offset[i], curR, curS, curC); - dst_float4[dst_index] = reinterpret_cast(&src[element_offset[i]])[0]; + // if(threadIdx.x == 3 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 1){ + // printf(" %u, %u, %u, %u, %lld, %d\n", i, curR, curS, curC, element_offset[i], valid?1:0); + // } + // if (valid && curC < end_k){ + if (valid){ + // if(element_offset[i] >= 327680 || element_offset[i] < 0) + // printf("%d, %d, %d, 
%d, %d, %d, %d, %d, %d \n", threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y, + // i, element_offset[i], curR, curS, curC); + dst_float4[dst_index] = reinterpret_cast(&src[element_offset[i]+curC])[0]; } else{ dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f); } @@ -263,6 +274,7 @@ __device__ __forceinline__ void tileMemcpySwizzleA( // } // thread_row += ROW_STEP; // } + return curC; #else GGML_UNUSED(src); GGML_UNUSED(dst); @@ -276,17 +288,18 @@ template -__device__ __forceinline__ void tileMemcpyLoadA( +__device__ __forceinline__ unsigned int tileMemcpyLoadA( const half* src, float4 (&dst_reg)[ELEMENTS_PER_THREAD], const unsigned int curR, const unsigned int curS, - const unsigned int masks[][2], + unsigned int masks[][2], const int64_t element_offset[], const unsigned int thread_idx, const unsigned int block_k, const unsigned int start_k, const unsigned int end_k, + unsigned int oldC, const unsigned int inChannelOffset, param_t param ){ @@ -301,7 +314,7 @@ __device__ __forceinline__ void tileMemcpyLoadA( // to cover the whole tile constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; - // unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; + unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; // compile time check that we provided the right amount of registers for storage @@ -313,13 +326,18 @@ __device__ __forceinline__ void tileMemcpyLoadA( // const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset // const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset // const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - const unsigned int curC = start_k+block_k+thread_col*8;; + const unsigned int curC = start_k+block_k+thread_col*8; + if (curC > oldC) + clear_mask(masks, curC >= end_k); #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++){ bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS)); - if (valid && curC < end_k) { - dst_reg[i] = reinterpret_cast(&src[element_offset[i]])[0]; + // if(threadIdx.x == 3 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 1){ + // printf(" %u, %u, %u, %u, %u, %lld, %d\n", i, curR, curS, oldC, curC, element_offset[i], valid?1:0); + // } + if (valid) { + dst_reg[i] = reinterpret_cast(&src[element_offset[i]+curC])[0]; } else{ dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f); } @@ -334,6 +352,17 @@ __device__ __forceinline__ void tileMemcpyLoadA( // // unsigned int inOffset = n * param.c * param.h * param.w; // int curH = posh_ori + curR * param.d_h; // input h // int curW = posw_ori + curS * param.d_w; // input w + // bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS)); + // bool ovl = curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && + // curR < param.r && curS < param.s && curC < param.c && n < param.n && ki < end_k; + // const int txx = curH * (int) inChannelOffset + curW * (int)param.c + (int)curC; + + // if(threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 1){ + // printf(" %u, %u, %u, %u, %u, %lld, %lld, %d, %d, %d\n", i, curR, curS, oldC, curC, + // element_offset[i], element_offset[i]+(int64_t)curC, n * (int)chw + txx, + // valid?1:0, ovl?1:0); + // } + // if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && // curR < param.r && curS < 
param.s && curC < param.c && n < param.n && ki < end_k){ // const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; @@ -343,6 +372,7 @@ __device__ __forceinline__ void tileMemcpyLoadA( // } // thread_row += ROW_STEP; // } + return curC; #else GGML_UNUSED(src); GGML_UNUSED(dst_reg); diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index e3968f28b8..d0f67aa53b 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -716,15 +716,15 @@ int main(void) // for(int i = 0; i < ggml_nelements(wino_res); i++) { // for(int i = 0; i < 26*38; i++) { - for(int i = 0; i < conv2d_data.size(); i++) { - float diff = fabs(im2col_data[i] - conv2d_data[i]); - // if(diff > 0.5) { - printf("(%7.3f, %7.3f, %.2f, %d) \n", - im2col_data[i], conv2d_data[i], - diff, i); - // break; - // } - } + // for(int i = 0; i < conv2d_data.size(); i++) { + // float diff = fabs(im2col_data[i] - conv2d_data[i]); + // // if(diff > 0.5) { + // printf("(%7.3f, %7.3f, %.2f, %d) \n", + // im2col_data[i], conv2d_data[i], + // diff, i); + // // break; + // // } + // } ggml_free(model.ctx); ggml_backend_buffer_free(model.buffer); From 0cb1ff419a4a4856bce3c16bc5a5d0d45af3df3e Mon Sep 17 00:00:00 2001 From: bssrdf Date: Fri, 14 Nov 2025 12:02:13 -0500 Subject: [PATCH 091/122] move some register to const memory space --- ggml/src/ggml-cuda/conv2d-implicit.cu | 50 ++++++++++++++++---------- ggml/src/ggml-cuda/conv2d-implicit.cuh | 18 ++++++---- 2 files changed, 44 insertions(+), 24 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 1b49a9beb1..d7ef4b5d95 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -781,13 +781,13 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, constexpr unsigned int MMA_M = 16; constexpr unsigned int MMA_N = 8; - const unsigned int K = param.c; - const uint inChannelOffset = param.c * param.w; - const uint weightKOffset = param.c * param.r * param.s; + // const unsigned int K = param.c; + // const uint inChannelOffset = param.c * param.w; + // const uint weightKOffset = param.c * param.r * param.s; - const unsigned int PQ = param.Ow * param.Oh; - const unsigned int KPQ = param.k * PQ; - const unsigned int NKPQ = param.n * KPQ; + // const unsigned int PQ = param.Ow * param.Oh; + // const unsigned int KPQ = param.k * PQ; + // const unsigned int NKPQ = param.n * KPQ; // loop bounds, constexpr where possible allows for loop unrolling #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE @@ -799,9 +799,9 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, constexpr unsigned int mma_tiles_per_warp_n = WN / MMA_N; const unsigned int z = blockIdx.z; - const unsigned int ks = (ksplit > 0) ? (K + ksplit - 1) / ksplit : K; + const unsigned int ks = (ksplit > 0) ? (param.c + ksplit - 1) / ksplit : param.c; const unsigned int start_k = (ksplit > 0) ? 
z * ks : 0; - const unsigned int end_k = min(start_k + ks, K); + const unsigned int end_k = min(start_k + ks, param.c); const unsigned int num_block_tiles_k = (ks + (BK-1)) / BK; const unsigned int num_block_tiles_krs = num_block_tiles_k * param.r * param.s; @@ -888,11 +888,11 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, // prefetch the first block tile of A,B into shared memory const half* A_block_gmem = input; - const half* B_block_gmem = kernel + block_n * BN * weightKOffset; + const half* B_block_gmem = kernel + block_n * BN * param.weightKOffset; unsigned int curC = tileMemcpySwizzleA(A_block_gmem, A_block_smem, 0, 0, masks_a, element_offset_a, - thread_idx, start_k, end_k, inChannelOffset, param); - tileMemcpySwizzleB(B_block_gmem, B_block_smem, 0, 0, start_k, end_k, weightKOffset, param); + thread_idx, start_k, end_k, param); + tileMemcpySwizzleB(B_block_gmem, B_block_smem, 0, 0, start_k, end_k, param); int offset_direction = 1; unsigned int block_k = 0; @@ -945,8 +945,8 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, if (block_krs != num_block_tiles_krs){ curC = tileMemcpyLoadA(A_block_gmem, A_gmem_cache_reg, r, s, masks_a, element_offset_a, thread_idx, block_k * BK, - start_k, end_k, curC, inChannelOffset, param); - tileMemcpyLoadB(B_block_gmem, B_gmem_cache_reg, r, s, block_k * BK, start_k, end_k, weightKOffset, param); + start_k, end_k, curC, param); + tileMemcpyLoadB(B_block_gmem, B_gmem_cache_reg, r, s, block_k * BK, start_k, end_k, param); } half* A_warp_tile = A_block_smem + A_warp_tile_offset; @@ -1064,12 +1064,12 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, const int col = fastmodulo(gemm_i, param.OHOW_fastdiv); uint32_t dst_ptr = *(reinterpret_cast(&smemoutput[idx+j*16*BN])); // 32*BN/2 = 16*BN half (&res_)[2] = reinterpret_cast(dst_ptr); - if (n < param.n && row < param.k && col < PQ) { - const uint outOffset = ((ksplit > 0) ? z * NKPQ : 0) + n * KPQ + row * PQ + col; + if (n < param.n && row < param.k && col < param.PQ) { + const uint outOffset = ((ksplit > 0) ? z * param.NKPQ : 0) + n * param.KPQ + row * param.PQ + col; output[outOffset] = ggml_cuda_cast(res_[0]); } - if (n < param.n && row+1 < param.k && col < PQ) { - const uint outOffset = ((ksplit > 0) ? z * NKPQ : 0) + n * KPQ + (row+1) * PQ + col; + if (n < param.n && row+1 < param.k && col < param.PQ) { + const uint outOffset = ((ksplit > 0) ? 
z * param.NKPQ : 0) + n * param.KPQ + (row+1) * param.PQ + col; output[outOffset] = ggml_cuda_cast(res_[1]); } } @@ -1389,6 +1389,14 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * const uint B = input->ne[3]; // n_batches + int64_t pp[3]; + // const unsigned int K = param.c; +// const uint inChannelOffset = param.c * param.w; +// const uint weightKOffset = param.c * param.r * param.s; +// const unsigned int PQ = param.Ow * param.Oh; +// const unsigned int KPQ = param.k * PQ; +// const unsigned int NKPQ = param.n * KPQ; + param_t params = { B, IC, IH, IW, OC, KH, KW, ST_Y, ST_X, PD_Y, PD_X, DL_Y, DL_X, OH, OW, init_fastdiv_values(KW*IC), @@ -1396,7 +1404,13 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * init_fastdiv_values(IC), init_fastdiv_values(KW*KH), init_fastdiv_values(KW), - init_fastdiv_values(OW*OH)}; + init_fastdiv_values(OW*OH), + pp[0], pp[1], pp[2], + IC*IW, + IC*KW*KH, + OW*OH, + OC*OW*OH, + B*OC*OW*OH}; if (kernel->type == GGML_TYPE_F16) { conv2d_implicit_cuda_f16(ctx, X_D, (half *) K_D, Y_D, cc, params, st); diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index 982699f969..40b1c7babe 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -24,6 +24,13 @@ typedef struct{ uint3 S_fastdiv; uint3 OHOW_fastdiv; int64_t inc_next[3]; + // unsigned int K; + unsigned int inChannelOffset; + unsigned int weightKOffset; + unsigned int PQ; + unsigned int KPQ; + unsigned int NKPQ; + } param_t; @@ -125,7 +132,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB( const unsigned int curS, const unsigned int start_k, const unsigned int end_k, - const unsigned int src_stride, + // const unsigned int src_stride, param_t param ){ #if __CUDA_ARCH__ >= GGML_CUDA_TURING @@ -161,7 +168,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB( #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++){ // apply swizzle to the dst index - const unsigned int src_index = thread_row * src_stride + ki; + const unsigned int src_index = thread_row * param.weightKOffset + ki; unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); @@ -195,7 +202,6 @@ __device__ __forceinline__ unsigned int tileMemcpySwizzleA( const unsigned int thread_idx, const unsigned int start_k, const unsigned int end_k, - const unsigned int inChannelOffset, param_t param ) { @@ -300,7 +306,7 @@ __device__ __forceinline__ unsigned int tileMemcpyLoadA( const unsigned int start_k, const unsigned int end_k, unsigned int oldC, - const unsigned int inChannelOffset, + // const unsigned int inChannelOffset, param_t param ){ #if __CUDA_ARCH__ >= GGML_CUDA_TURING @@ -396,7 +402,7 @@ __device__ __forceinline__ void tileMemcpyLoadB( const unsigned int block_k, const unsigned int start_k, const unsigned int end_k, - const unsigned int src_stride, + // const unsigned int src_stride, param_t param ){ #if __CUDA_ARCH__ >= GGML_CUDA_TURING @@ -426,7 +432,7 @@ __device__ __forceinline__ void tileMemcpyLoadB( #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++){ - const unsigned int src_index = thread_row * src_stride + ki; + const unsigned int src_index = thread_row * param.weightKOffset + ki; if (thread_row + blockIdx.x * TILE_ROWS < param.k && curC < end_k){ dst_reg[i] = reinterpret_cast(&src[src_index])[0]; }else{ // 
read 4 halves From b4530b4f8ba2ba430b8ffcbb1d14d13176054b9a Mon Sep 17 00:00:00 2001 From: bssrdf Date: Fri, 14 Nov 2025 12:11:52 -0500 Subject: [PATCH 092/122] disable m16n8k16 mma for ampere for now --- ggml/src/ggml-cuda/conv2d-implicit.cu | 37 +++++++++++++++------------ 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index d7ef4b5d95..57cd116d73 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -7,6 +7,9 @@ typedef unsigned int uint; + +#define GGML_CUDA_CC_RUBIN 10000 + constexpr uint WARPSIZE = 32; #define CUDA_NCHW_2_NHWC_TILE_DIM 32 #define CUDA_NCHW_2_NHWC_BLOCK_NM 8 @@ -343,7 +346,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, template __device__ __forceinline__ void ldmatrix_a( const half* src, -#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN half (®)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][8] #else half (®)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] @@ -351,13 +354,13 @@ __device__ __forceinline__ void ldmatrix_a( ){ #if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING static_assert(mma_tiles_per_warp_m == 8, "mma_tiles_per_warp_m must be 8"); -#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN static_assert(mma_tiles_per_warp_k == 2, "mma_tiles_per_warp_k must be 2"); #else static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4"); #endif -#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN uint32_t (®_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] = reinterpret_cast(reg); #else uint32_t (®_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][2] = reinterpret_cast(reg); @@ -403,7 +406,7 @@ __device__ __forceinline__ void ldmatrix_a( src_addr ^= 0b10000; // 1 -#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" @@ -471,7 +474,7 @@ __device__ __forceinline__ void ldmatrix_a( src_addr ^= 0b110000; // 2 -#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" @@ -537,7 +540,7 @@ __device__ __forceinline__ void ldmatrix_a( src_addr ^= 0b10000; // 3 -#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" @@ -610,7 +613,7 @@ __device__ __forceinline__ void ldmatrix_a( template __device__ __forceinline__ void ldmatrix_b( const half* src, -#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN half (®)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][4] #else half (®)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2] @@ -618,14 +621,14 @@ __device__ __forceinline__ void ldmatrix_b( ){ #if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING -#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN static_assert(mma_tiles_per_warp_k == 2, "mma_tiles_per_warp_k must be 2"); #else static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4"); #endif static_assert(mma_tiles_per_warp_n == 8, "mma_tiles_per_warp_n must be 8"); -#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN uint32_t (®_) [2][8][2] = reinterpret_cast(reg); #else uint32_t (®_) [4][8] = reinterpret_cast(reg); @@ -637,7 +640,7 @@ __device__ 
__forceinline__ void ldmatrix_b( constexpr unsigned int smem_stride_ = smem_stride * sizeof(half); // convert stride to bytes // 0 -#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" @@ -670,7 +673,7 @@ __device__ __forceinline__ void ldmatrix_b( src_addr ^= 0b10000; -#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" @@ -702,7 +705,7 @@ __device__ __forceinline__ void ldmatrix_b( src_addr ^= 0b110000; -#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" @@ -734,7 +737,7 @@ __device__ __forceinline__ void ldmatrix_b( src_addr ^= 0b10000; -#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" @@ -790,7 +793,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, // const unsigned int NKPQ = param.n * KPQ; // loop bounds, constexpr where possible allows for loop unrolling -#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN constexpr unsigned int mma_tiles_per_warp_k = 2; #else constexpr unsigned int mma_tiles_per_warp_k = 4; @@ -829,7 +832,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, // declare register storage // ptx instructions expect uint32_t registers, where each uint32_t is 2 halfs packed together uint32_t acc_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][2]; -#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN uint32_t A_register[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4]; uint32_t B_register[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2]; #else @@ -839,7 +842,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, // convenience cast to half for register storage half (&acc_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_n][4] = reinterpret_cast(acc_register); -#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN half (&A_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][8] = reinterpret_cast(A_register); half (&B_register_) [mma_tiles_per_warp_k][mma_tiles_per_warp_n][4] = reinterpret_cast(B_register); #else @@ -962,7 +965,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++){ #pragma unroll for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++){ -#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN asm volatile ( "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " "{%0, %1}, " From ecbbdb6608b6c9fbd107fbf0153e16be3c0b5176 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Fri, 14 Nov 2025 13:05:31 -0500 Subject: [PATCH 093/122] reducing integer ops --- ggml/src/ggml-cuda/conv2d-implicit.cu | 20 +++++---- ggml/src/ggml-cuda/conv2d-implicit.cuh | 62 +++++++++++++++----------- 2 files changed, 47 insertions(+), 35 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 57cd116d73..d204807a2f 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -822,6 +822,8 @@ static __global__ void 
conv2d_implicit_kernel(const half * __restrict__ input, const unsigned int warp_m = threadIdx.y; const unsigned int warp_n = threadIdx.x / 32; const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; + unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; + const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; // double buffering extern __shared__ half shmem[]; @@ -871,7 +873,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, - prepareIteratorA(thread_idx, masks_a, element_offset_a, param); + prepareIteratorA(thread_row, masks_a, element_offset_a, param); // for(int kk =0; kk < A_K_STRID; kk++){ // if(element_offset_a[kk] >= 327680) @@ -894,8 +896,8 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, const half* B_block_gmem = kernel + block_n * BN * param.weightKOffset; unsigned int curC = tileMemcpySwizzleA(A_block_gmem, A_block_smem, 0, 0, masks_a, element_offset_a, - thread_idx, start_k, end_k, param); - tileMemcpySwizzleB(B_block_gmem, B_block_smem, 0, 0, start_k, end_k, param); + thread_row, thread_col, start_k, end_k, param); + tileMemcpySwizzleB(B_block_gmem, B_block_smem, 0, 0, start_k, end_k, thread_row, thread_col, param); int offset_direction = 1; unsigned int block_k = 0; @@ -947,9 +949,10 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, // if (block_k != num_block_tiles_k){ if (block_krs != num_block_tiles_krs){ curC = tileMemcpyLoadA(A_block_gmem, A_gmem_cache_reg, r, s, - masks_a, element_offset_a, thread_idx, block_k * BK, + masks_a, element_offset_a, thread_row, thread_col, block_k * BK, start_k, end_k, curC, param); - tileMemcpyLoadB(B_block_gmem, B_gmem_cache_reg, r, s, block_k * BK, start_k, end_k, param); + tileMemcpyLoadB(B_block_gmem, B_gmem_cache_reg, r, s, block_k * BK, + start_k, end_k, thread_row, thread_col, param); } half* A_warp_tile = A_block_smem + A_warp_tile_offset; @@ -1002,8 +1005,8 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, B_block_smem = B_block_smem + BUFFER_SIZE * offset_direction; offset_direction = -1 * offset_direction; - tileMemcpySwizzleStore(A_gmem_cache_reg, A_block_smem); - tileMemcpySwizzleStore(B_gmem_cache_reg, B_block_smem); + tileMemcpySwizzleStore(A_gmem_cache_reg, A_block_smem, thread_row, thread_col); + tileMemcpySwizzleStore(B_gmem_cache_reg, B_block_smem, thread_row, thread_col); } block_krs++; @@ -1413,7 +1416,8 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * IC*KW*KH, OW*OH, OC*OW*OH, - B*OC*OW*OH}; + B*OC*OW*OH, + IC*IW*IH}; if (kernel->type == GGML_TYPE_F16) { conv2d_implicit_cuda_f16(ctx, X_D, (half *) K_D, Y_D, cc, params, st); diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index 40b1c7babe..9f817a0078 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -30,7 +30,7 @@ typedef struct{ unsigned int PQ; unsigned int KPQ; unsigned int NKPQ; - + unsigned int CHW; } param_t; @@ -58,7 +58,7 @@ template -__device__ void prepareIteratorA(const int thread_idx, +__device__ void prepareIteratorA(unsigned int thread_row, unsigned int masks[][2], int64_t element_offset[], const param_t param){ @@ -67,8 +67,8 @@ __device__ void prepareIteratorA(const int thread_idx, int offset_q[A_K_STRID]; constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; - unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; - const unsigned int chw = param.c 
* param.h * param.w; + // unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; + // const unsigned int chw = param.c * param.h * param.w; #pragma unroll for (int s = 0; s < A_K_STRID; ++s) { @@ -91,7 +91,7 @@ __device__ void prepareIteratorA(const int thread_idx, // offset_npq, offset_n[s], offset_p[s], offset_q[s], AccessType::kElements, // ThreadMap::Iterations::kContiguous); - element_offset[s] = offset_n[s] * (int64_t)chw + h * (int64_t)(param.c * param.w) + w * (int64_t)param.c; + element_offset[s] = offset_n[s] * (int64_t)param.CHW + h * (int64_t)(param.inChannelOffset) + w * (int64_t)param.c; // if(element_offset[s] >= 327680) // printf("(%d, %d, %d, %d, %d), %d, %lld, %d, %d, %d, %d, %d, %u, %u, %u \n", @@ -126,12 +126,14 @@ __device__ void prepareIteratorA(const int thread_idx, template __device__ __forceinline__ void tileMemcpySwizzleB( - const half* src, - half* dst, + const half* __restrict__ src, + half* __restrict__ dst, const unsigned int curR, const unsigned int curS, const unsigned int start_k, const unsigned int end_k, + unsigned int thread_row, + const unsigned int thread_col, // const unsigned int src_stride, param_t param ){ @@ -149,14 +151,14 @@ __device__ __forceinline__ void tileMemcpySwizzleB( constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); // flatten out 2d grid of threads into in order of increasing threadIdx.x - const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; + // const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; // assign each thread a row/column in the tile, calculate how many iterations we need // to cover the whole tile constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; - unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; - const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; + // unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; + // const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; // const unsigned int ki = (curR*param.s+curS)*param.c + start_k+thread_col*8; // const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset @@ -193,13 +195,14 @@ __device__ __forceinline__ void tileMemcpySwizzleB( template __device__ __forceinline__ unsigned int tileMemcpySwizzleA( - const half* src, - half* dst, + const half* __restrict__ src, + half* __restrict__ dst, const unsigned int curR, const unsigned int curS, unsigned int masks[][2], const int64_t element_offset[], - const unsigned int thread_idx, + unsigned int thread_row, + const unsigned int thread_col, const unsigned int start_k, const unsigned int end_k, param_t param @@ -225,8 +228,8 @@ __device__ __forceinline__ unsigned int tileMemcpySwizzleA( // to cover the whole tile constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; - unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; - const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; + // unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; + // const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; // const unsigned int ki = start_k+thread_col*8; // const unsigned int chw = param.c * param.h * param.w; @@ -295,13 +298,14 @@ unsigned int TILE_COLS, unsigned int NUM_THREADS, unsigned int ELEMENTS_PER_THREAD> __device__ __forceinline__ unsigned int tileMemcpyLoadA( - const half* src, + 
const half* __restrict__ src, float4 (&dst_reg)[ELEMENTS_PER_THREAD], const unsigned int curR, const unsigned int curS, unsigned int masks[][2], const int64_t element_offset[], - const unsigned int thread_idx, + unsigned int thread_row, + const unsigned int thread_col, const unsigned int block_k, const unsigned int start_k, const unsigned int end_k, @@ -320,8 +324,8 @@ __device__ __forceinline__ unsigned int tileMemcpyLoadA( // to cover the whole tile constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; - unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; - const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; + // unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; + // const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; // compile time check that we provided the right amount of registers for storage static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); @@ -395,13 +399,15 @@ unsigned int TILE_COLS, unsigned int NUM_THREADS, unsigned int ELEMENTS_PER_THREAD> __device__ __forceinline__ void tileMemcpyLoadB( - const half* src, + const half* __restrict__ src, float4 (&dst_reg)[ELEMENTS_PER_THREAD], const unsigned int curR, const unsigned int curS, const unsigned int block_k, const unsigned int start_k, const unsigned int end_k, + unsigned int thread_row, + const unsigned int thread_col, // const unsigned int src_stride, param_t param ){ @@ -412,14 +418,14 @@ __device__ __forceinline__ void tileMemcpyLoadB( static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); // flatten out 2d grid of threads into in order of increasing threadIdx.x - const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; + // const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; // assign each thread a row/column in the tile, calculate how many iterations we need // to cover the whole tile constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; - unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; - const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; + // unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; + // const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; // compile time check that we provided the right amount of registers for storage static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); @@ -459,7 +465,9 @@ unsigned int NUM_THREADS, unsigned int ELEMENTS_PER_THREAD> __device__ __forceinline__ void tileMemcpySwizzleStore( const float4 (&src_reg)[ELEMENTS_PER_THREAD], - half* dst + half* __restrict__ dst, + unsigned int thread_row, + const unsigned int thread_col ) { #if __CUDA_ARCH__ >= GGML_CUDA_TURING @@ -478,14 +486,14 @@ __device__ __forceinline__ void tileMemcpySwizzleStore( static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); // flatten out 2d grid of threads into in order of increasing threadIdx.x - const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; + // const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; // assign each thread a row/column in the tile, calculate how many iterations we need // to cover the whole tile constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; - unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; - const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; + // unsigned int thread_row = 
thread_idx / TILE_COLS_VECTORIZED; + // const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; // compile time check that we provided the right amount of registers for storage static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); From e4fbece60685542012ddd5dece5b46a72404615d Mon Sep 17 00:00:00 2001 From: bssrdf Date: Fri, 14 Nov 2025 13:51:07 -0500 Subject: [PATCH 094/122] various small optimizations --- ggml/src/ggml-cuda/conv2d-implicit.cu | 21 +++++++++++++++---- ggml/src/ggml-cuda/conv2d-implicit.cuh | 28 ++++++++++++++++++-------- 2 files changed, 37 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index d204807a2f..5bb5cd7cad 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -811,7 +811,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, constexpr unsigned int TILE_COLS_VECTORIZED = BK / 8; constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; constexpr unsigned int A_K_STRID = BM / ROW_STEP; - constexpr unsigned int B_K_STRID = BN / ROW_STEP; + // constexpr unsigned int B_K_STRID = BN / ROW_STEP; unsigned int masks_a[A_K_STRID][2]; int64_t element_offset_a[A_K_STRID]; @@ -1263,9 +1263,9 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa // if (BlocksM * BlocksN < nsm && P.c >= 8 * ksplit && (P.c * P.r * P.s) % (8*ksplit) == 0) { if (BlocksM * BlocksN < 2*(unsigned int)nsm){ int j, max_remaining_waves = -1, candidate = -1; - int ks = min(16, nsm / (BlocksM * BlocksN)); + int ks = min(20, nsm / (BlocksM * BlocksN)); if (ks < 2 && (BlocksM * BlocksN) % nsm < nsm*4/5) - ks = 16; + ks = 20; for (j = 2; j <= ks; j++){ const int remainder = (BlocksM * BlocksN * j) % nsm; // if ((P.c * P.r * P.s) % (8*j) == 0){ @@ -1328,7 +1328,20 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa } else if (j == 16) { launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else if (j == 17) { + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else if (j == 18) { + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else if (j == 19) { + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else if (j == 20) { + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); } + return; } } @@ -1395,7 +1408,7 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * const uint B = input->ne[3]; // n_batches - int64_t pp[3]; + int64_t pp[3] = {0}; // const unsigned int K = param.c; // const uint inChannelOffset = param.c * param.w; // const uint weightKOffset = param.c * param.r * param.s; diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index 9f817a0078..924b678b81 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -37,7 +37,8 @@ typedef struct{ /// Clears the predicates template -__host__ __device__ void clear_mask(unsigned int masks_[][2], bool clear = true) { +// __host__ __device__ void clear_mask(unsigned int masks_[][2], bool clear = true) { +__device__ void clear_mask(unsigned int masks_[][2], bool clear = true) { #pragma unroll for (int s = 0; s < K_STRID; ++s) { @@ -47,7 +48,8 @@ __host__ __device__ void clear_mask(unsigned int 
masks_[][2], bool clear = true) } template -__host__ __device__ void add_byte_offset(int64_t element_offset[], const int64_t offset){ +// __host__ __device__ void add_byte_offset(int64_t element_offset[], const int64_t offset){ +__device__ void add_byte_offset(int64_t element_offset[], const int64_t offset){ #pragma unroll for (int s = 0; s < K_STRID; ++s) { element_offset[s] += offset; @@ -66,7 +68,7 @@ __device__ void prepareIteratorA(unsigned int thread_row, int offset_p[A_K_STRID]; int offset_q[A_K_STRID]; - constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; + // constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; // unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; // const unsigned int chw = param.c * param.h * param.w; @@ -436,15 +438,22 @@ __device__ __forceinline__ void tileMemcpyLoadB( const unsigned int curC = start_k+block_k+thread_col*8; const unsigned int ki = (curR*param.s+curS)*param.c + curC; + unsigned int iter_idx = thread_row * param.weightKOffset + ki; + unsigned int krow_idx = thread_row + blockIdx.x * TILE_ROWS; + const int ITER_STEPS = ROW_STEP * param.weightKOffset; + #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++){ - const unsigned int src_index = thread_row * param.weightKOffset + ki; - if (thread_row + blockIdx.x * TILE_ROWS < param.k && curC < end_k){ + // const unsigned int src_index = thread_row * param.weightKOffset + ki; + const unsigned int src_index = iter_idx; + // if (thread_row + blockIdx.x * TILE_ROWS < param.k && curC < end_k){ + if (krow_idx < param.k && curC < end_k){ dst_reg[i] = reinterpret_cast(&src[src_index])[0]; }else{ // read 4 halves dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f); } - thread_row += ROW_STEP; + krow_idx += ROW_STEP; + iter_idx += ITER_STEPS; } #else GGML_UNUSED(src); @@ -492,21 +501,24 @@ __device__ __forceinline__ void tileMemcpySwizzleStore( // to cover the whole tile constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; + constexpr unsigned int ITER_STEPS = ROW_STEP * TILE_COLS_VECTORIZED; // unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; // const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; // compile time check that we provided the right amount of registers for storage static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); + unsigned int iter_idx = thread_row * TILE_COLS_VECTORIZED + thread_col; #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++) { // apply swizzle to the dst index - unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; + // unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; + unsigned int dst_index = iter_idx; dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); dst_float4[dst_index] = src_reg[i]; - thread_row += ROW_STEP; + iter_idx += ITER_STEPS; } #else GGML_UNUSED(src_reg); From 11bd9806bfe99fa030070efdaf71e600a6acc2cb Mon Sep 17 00:00:00 2001 From: bssrdf Date: Fri, 14 Nov 2025 17:01:24 -0500 Subject: [PATCH 095/122] add/fix GGML_UNUSED --- ggml/src/ggml-cuda/conv2d-implicit.cuh | 35 +++++++++++++++++++++++--- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index 924b678b81..b0d8c17a50 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -186,7 +186,12 @@ __device__ __forceinline__ 
void tileMemcpySwizzleB( #else GGML_UNUSED(src); GGML_UNUSED(dst); - GGML_UNUSED(src_stride); + GGML_UNUSED(curR); + GGML_UNUSED(curS); + GGML_UNUSED(start_k); + GGML_UNUSED(end_k); + GGML_UNUSED(thread_row); + GGML_UNUSED(thread_col); GGML_UNUSED(param); NO_DEVICE_CODE; #endif @@ -289,7 +294,14 @@ __device__ __forceinline__ unsigned int tileMemcpySwizzleA( #else GGML_UNUSED(src); GGML_UNUSED(dst); - GGML_UNUSED(inChannelOffset); + GGML_UNUSED(curR); + GGML_UNUSED(curS); + GGML_UNUSED(start_k); + GGML_UNUSED(end_k); + GGML_UNUSED(masks); + GGML_UNUSED(element_offset); + GGML_UNUSED(thread_row); + GGML_UNUSED(thread_col); GGML_UNUSED(param); NO_DEVICE_CODE; #endif @@ -389,7 +401,15 @@ __device__ __forceinline__ unsigned int tileMemcpyLoadA( GGML_UNUSED(src); GGML_UNUSED(dst_reg); GGML_UNUSED(block_k); - GGML_UNUSED(inChannelOffset); + GGML_UNUSED(curR); + GGML_UNUSED(curS); + GGML_UNUSED(start_k); + GGML_UNUSED(end_k); + GGML_UNUSED(masks); + GGML_UNUSED(element_offset); + GGML_UNUSED(thread_row); + GGML_UNUSED(thread_col); + GGML_UNUSED(oldC); GGML_UNUSED(param); NO_DEVICE_CODE; #endif @@ -459,7 +479,12 @@ __device__ __forceinline__ void tileMemcpyLoadB( GGML_UNUSED(src); GGML_UNUSED(dst_reg); GGML_UNUSED(block_k); - GGML_UNUSED(src_stride); + GGML_UNUSED(curR); + GGML_UNUSED(curS); + GGML_UNUSED(start_k); + GGML_UNUSED(end_k); + GGML_UNUSED(thread_row); + GGML_UNUSED(thread_col); GGML_UNUSED(param); NO_DEVICE_CODE; #endif @@ -523,6 +548,8 @@ __device__ __forceinline__ void tileMemcpySwizzleStore( #else GGML_UNUSED(src_reg); GGML_UNUSED(dst); + GGML_UNUSED(thread_row); + GGML_UNUSED(thread_col); NO_DEVICE_CODE; #endif } From 378bb8368e13b194d1f670b7ff3d15a14a663b0f Mon Sep 17 00:00:00 2001 From: bssrdf Date: Fri, 14 Nov 2025 18:48:06 -0500 Subject: [PATCH 096/122] WIP: adding cp.async calls --- ggml/src/ggml-cuda/conv2d-implicit.cuh | 28 ++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index b0d8c17a50..f2c3d60998 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -176,11 +176,25 @@ __device__ __forceinline__ void tileMemcpySwizzleB( unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + unsigned int smem_ptr; + void *ptr = (void *)(dst); + int src_in_bytes = thread_row + blockIdx.x * TILE_ROWS < param.k && curC < end_k ? 16 : 0; + asm("{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 " + "%0, smem_ptr; }\n" + : "=r"(smem_ptr) + : "l"(ptr)); + + asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_ptr), + "l"(&src[src_index]), + "n"(16), "r"(src_in_bytes)); +#else if (thread_row + blockIdx.x * TILE_ROWS < param.k && curC < end_k){ dst_float4[dst_index] = reinterpret_cast(&src[src_index])[0]; }else{ // read 4 halves dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f); } +#endif thread_row += ROW_STEP; } #else @@ -257,6 +271,19 @@ __device__ __forceinline__ unsigned int tileMemcpySwizzleA( // printf(" %u, %u, %u, %u, %lld, %d\n", i, curR, curS, curC, element_offset[i], valid?1:0); // } // if (valid && curC < end_k){ +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + unsigned int smem_ptr; + void *ptr = (void *)(dst); + int src_in_bytes = valid ? 
16 : 0; + asm("{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 " + "%0, smem_ptr; }\n" + : "=r"(smem_ptr) + : "l"(ptr)); + + asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_ptr), + "l"(&src[element_offset[i]+curC]), + "n"(16), "r"(src_in_bytes)); +#else if (valid){ // if(element_offset[i] >= 327680 || element_offset[i] < 0) // printf("%d, %d, %d, %d, %d, %d, %d, %d, %d \n", threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y, @@ -265,6 +292,7 @@ __device__ __forceinline__ unsigned int tileMemcpySwizzleA( } else{ dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f); } +#endif thread_row += ROW_STEP; } // #pragma unroll From dbeb6ced466547cff5d5de0a010d7d4d322d95fd Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sat, 15 Nov 2025 00:18:26 -0500 Subject: [PATCH 097/122] WIP: debugging --- ggml/src/ggml-cuda/conv2d-implicit.cu | 89 +++++++-- ggml/src/ggml-cuda/conv2d-implicit.cuh | 250 ++++++++++++++++++++----- tests/test-conv2d.cpp | 18 +- 3 files changed, 285 insertions(+), 72 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 5bb5cd7cad..d21e13d5ea 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -831,6 +831,15 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, half* B_block_smem = &shmem[BM * BK]; constexpr int BUFFER_SIZE = BM * BK + BK * BN; +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + half* SA1 = A_block_smem; + half* SB1 = B_block_smem; + half* SA2 = &shmem[BUFFER_SIZE]; + half* SB2 = SA2 + BM * BK; +#else + float4 A_gmem_cache_reg[4]; + float4 B_gmem_cache_reg[4]; +#endif // declare register storage // ptx instructions expect uint32_t registers, where each uint32_t is 2 halfs packed together uint32_t acc_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][2]; @@ -868,9 +877,6 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, static_assert(BN == 256); static_assert(BK == 32); static_assert(NUM_THREADS == 256); - float4 A_gmem_cache_reg[4]; - float4 B_gmem_cache_reg[4]; - prepareIteratorA(thread_row, masks_a, element_offset_a, param); @@ -898,7 +904,9 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, unsigned int curC = tileMemcpySwizzleA(A_block_gmem, A_block_smem, 0, 0, masks_a, element_offset_a, thread_row, thread_col, start_k, end_k, param); tileMemcpySwizzleB(B_block_gmem, B_block_smem, 0, 0, start_k, end_k, thread_row, thread_col, param); - +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + asm volatile("cp.async.commit_group;\n" ::); +#endif int offset_direction = 1; unsigned int block_k = 0; unsigned int block_krs = 1; @@ -906,6 +914,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, int s = 0; int r = 0; while (block_k < num_block_tiles_k){ + asm volatile("cp.async.wait_group %0;\n" ::"n"(0)); __syncthreads(); // moves to the next tile @@ -948,15 +957,29 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, // if (block_k != num_block_tiles_k){ if (block_krs != num_block_tiles_krs){ +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + curC = tileMemcpyAsyncLoadA(A_block_gmem, SA2, r, s, + masks_a, element_offset_a, thread_row, thread_col, block_k * BK, + start_k, end_k, curC, param); + tileMemcpyAsyncLoadB(B_block_gmem, SB2, r, s, block_k * BK, + start_k, end_k, thread_row, thread_col, param); + asm volatile("cp.async.commit_group;\n" ::); +#else curC = tileMemcpyLoadA(A_block_gmem, 
A_gmem_cache_reg, r, s, masks_a, element_offset_a, thread_row, thread_col, block_k * BK, start_k, end_k, curC, param); tileMemcpyLoadB(B_block_gmem, B_gmem_cache_reg, r, s, block_k * BK, start_k, end_k, thread_row, thread_col, param); +#endif } +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + half* A_warp_tile = SA1 + A_warp_tile_offset; + half* B_warp_tile = SB1 + B_warp_tile_offset; +#else half* A_warp_tile = A_block_smem + A_warp_tile_offset; half* B_warp_tile = B_block_smem + B_warp_tile_offset; +#endif ldmatrix_a(A_warp_tile, A_register_); ldmatrix_b(B_warp_tile, B_register_); @@ -998,8 +1021,11 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, } // if (block_k != num_block_tiles_k) - if (block_krs != num_block_tiles_krs) - { + if (block_krs != num_block_tiles_krs) { +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + half *tmp = SA1; SA1 = SA2; SA2 = tmp; + tmp = SB1; SB1 = SB2; SB2 = tmp; +#else // switch smem buffers each iteration A_block_smem = A_block_smem + BUFFER_SIZE * offset_direction; B_block_smem = B_block_smem + BUFFER_SIZE * offset_direction; @@ -1007,15 +1033,56 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, tileMemcpySwizzleStore(A_gmem_cache_reg, A_block_smem, thread_row, thread_col); tileMemcpySwizzleStore(B_gmem_cache_reg, B_block_smem, thread_row, thread_col); +#endif } - block_krs++; - } - // A_block_smem = shmem; - // B_block_smem = &shmem[BM * BK]; - // } // iter block_k + +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + asm volatile("cp.async.wait_group %0;\n" ::"n"(0)); + __syncthreads(); + half* A_warp_tile = SA2 + A_warp_tile_offset; + half* B_warp_tile = SB2 + B_warp_tile_offset; + ldmatrix_a(A_warp_tile, A_register_); + ldmatrix_b(B_warp_tile, B_register_); + // outer product between mma tiles +#pragma unroll + for (unsigned int mma_k = 0; mma_k < mma_tiles_per_warp_k; mma_k++){ +#pragma unroll + for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++){ +#pragma unroll + for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++){ +#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN + asm volatile ( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0, %1}, " + "{%2, %3, %4, %5}, " + "{%6, %7}, " + "{%8, %9};" + : "=r"(acc_register[mma_m][mma_n][0]), "=r"(acc_register[mma_m][mma_n][1]) + : "r"(A_register[mma_m][mma_k][0]), "r"(A_register[mma_m][mma_k][1]),"r"(A_register[mma_m][mma_k][2]), "r"(A_register[mma_m][mma_k][3]), + "r"(B_register[mma_k][mma_n][0]), "r"(B_register[mma_k][mma_n][1]) + "r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1]) + ); +#else + asm volatile ( + "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " + "{%0, %1}, " + "{%2, %3}, " + "{%4}, " + "{%5, %6};" + : "=r"(acc_register[mma_m][mma_n][0]), "=r"(acc_register[mma_m][mma_n][1]) + : "r"(A_register[mma_m][mma_k][0]), "r"(A_register[mma_m][mma_k][1]), + "r"(B_register[mma_k][mma_n]) + "r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1]) + ); +#endif + } + } + } +#endif + // if(threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0){ // printf(" %u, %f\n", blockIdx.z, __half2float(acc_register_[0][0][0])); diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index f2c3d60998..4728e0e757 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -124,6 +124,28 @@ __device__ void prepareIteratorA(unsigned int thread_row, } } +template +__device__ void cp_async_zfill(void *ptr, void const 
*global_ptr, bool pred_guard = true) { +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + + unsigned int smem_ptr; + int src_in_bytes = pred_guard ? preload : 0; + + asm("{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 " + "%0, smem_ptr; }\n" + : "=r"(smem_ptr) + : "l"(ptr)); + + asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_ptr), + "l"(global_ptr), + "n"(preload), "r"(src_in_bytes)); +#else + GGML_UNUSED(ptr); + GGML_UNUSED(global_ptr); + GGML_UNUSED(pred_guard); +#endif +} + // same as above, but writes are swizzled to avoid bank conflicts when shared memory is read later in the kernel template @@ -177,17 +199,10 @@ __device__ __forceinline__ void tileMemcpySwizzleB( dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE - unsigned int smem_ptr; - void *ptr = (void *)(dst); - int src_in_bytes = thread_row + blockIdx.x * TILE_ROWS < param.k && curC < end_k ? 16 : 0; - asm("{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 " - "%0, smem_ptr; }\n" - : "=r"(smem_ptr) - : "l"(ptr)); - asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_ptr), - "l"(&src[src_index]), - "n"(16), "r"(src_in_bytes)); + cp_async_zfill((void *)(&dst_float4[dst_index]), (void const *)(&src[src_index]), + thread_row + blockIdx.x * TILE_ROWS < param.k && curC < end_k); + #else if (thread_row + blockIdx.x * TILE_ROWS < param.k && curC < end_k){ dst_float4[dst_index] = reinterpret_cast(&src[src_index])[0]; @@ -272,24 +287,14 @@ __device__ __forceinline__ unsigned int tileMemcpySwizzleA( // } // if (valid && curC < end_k){ #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE - unsigned int smem_ptr; - void *ptr = (void *)(dst); - int src_in_bytes = valid ? 
16 : 0; - asm("{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 " - "%0, smem_ptr; }\n" - : "=r"(smem_ptr) - : "l"(ptr)); - - asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_ptr), - "l"(&src[element_offset[i]+curC]), - "n"(16), "r"(src_in_bytes)); + cp_async_zfill((void *)(&dst_float4[dst_index]), (void const *)(&src[element_offset[i]+curC]), valid); #else if (valid){ // if(element_offset[i] >= 327680 || element_offset[i] < 0) // printf("%d, %d, %d, %d, %d, %d, %d, %d, %d \n", threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y, // i, element_offset[i], curR, curS, curC); dst_float4[dst_index] = reinterpret_cast(&src[element_offset[i]+curC])[0]; - } else{ + } else { dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f); } #endif @@ -394,36 +399,6 @@ __device__ __forceinline__ unsigned int tileMemcpyLoadA( dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f); } } - // #pragma unroll - // for (unsigned int i = 0; i < NUM_ITERS; i++){ - // unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row; - // unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv); - // unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv); - // int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p; - // int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q; - // // unsigned int inOffset = n * param.c * param.h * param.w; - // int curH = posh_ori + curR * param.d_h; // input h - // int curW = posw_ori + curS * param.d_w; // input w - // bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS)); - // bool ovl = curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && - // curR < param.r && curS < param.s && curC < param.c && n < param.n && ki < end_k; - // const int txx = curH * (int) inChannelOffset + curW * (int)param.c + (int)curC; - - // if(threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 1){ - // printf(" %u, %u, %u, %u, %u, %lld, %lld, %d, %d, %d\n", i, curR, curS, oldC, curC, - // element_offset[i], element_offset[i]+(int64_t)curC, n * (int)chw + txx, - // valid?1:0, ovl?1:0); - // } - - // if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && - // curR < param.r && curS < param.s && curC < param.c && n < param.n && ki < end_k){ - // const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; - // dst_reg[i] = reinterpret_cast(&src[n * chw + inOffsetTmp])[0]; - // } else{ - // dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f); - // } - // thread_row += ROW_STEP; - // } return curC; #else GGML_UNUSED(src); @@ -443,6 +418,93 @@ __device__ __forceinline__ unsigned int tileMemcpyLoadA( #endif } +template +__device__ __forceinline__ unsigned int tileMemcpyAsyncLoadA( + const half* __restrict__ src, + half* __restrict__ dst, + const unsigned int curR, + const unsigned int curS, + unsigned int masks[][2], + const int64_t element_offset[], + unsigned int thread_row, + const unsigned int thread_col, + const unsigned int block_k, + const unsigned int start_k, + const unsigned int end_k, + unsigned int oldC, + // const unsigned int inChannelOffset, + param_t param +){ +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + // # of threads is multiple of # of columns in the tile + constexpr unsigned int SWIZZLE_MASK_1 = 0b10000; + constexpr unsigned int SWIZZLE_BITS_1 = 4; + constexpr unsigned int SWIZZLE_MASK_2 = 0b1100; + constexpr unsigned int SWIZZLE_BITS_2 = 2; + + constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; + static_assert(NUM_THREADS % 
TILE_COLS_VECTORIZED == 0); + + float4* dst_float4 = reinterpret_cast(dst); + + // flatten out 2d grid of threads into in order of increasing threadIdx.x + // assign each thread a row/column in the tile, calculate how many iterations we need + // to cover the whole tile + constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; + constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; + constexpr unsigned int ITER_STEPS = ROW_STEP * TILE_COLS_VECTORIZED; + // unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; + // const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; + + // compile time check that we provided the right amount of registers for storage + static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); + + // const unsigned int ki = start_k+block_k+thread_col*8; + // const unsigned int chw = param.c * param.h * param.w; + + // const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset + // const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + // const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const unsigned int curC = start_k+block_k+thread_col*8; + if (curC > oldC) + clear_mask(masks, curC >= end_k); + + unsigned int iter_idx = thread_row * TILE_COLS_VECTORIZED + thread_col; + #pragma unroll + for (unsigned int i = 0; i < NUM_ITERS; i++){ + bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS)); + // if(threadIdx.x == 3 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 1){ + // printf(" %u, %u, %u, %u, %u, %lld, %d\n", i, curR, curS, oldC, curC, element_offset[i], valid?1:0); + // } + unsigned int dst_index = iter_idx; + dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); + dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); + + cp_async_zfill((void *)(&dst_float4[dst_index]), (void const *)(&src[element_offset[i]+curC]), valid); + iter_idx += ITER_STEPS; + } + return curC; +#else + GGML_UNUSED(src); + GGML_UNUSED(dst); + GGML_UNUSED(block_k); + GGML_UNUSED(curR); + GGML_UNUSED(curS); + GGML_UNUSED(start_k); + GGML_UNUSED(end_k); + GGML_UNUSED(masks); + GGML_UNUSED(element_offset); + GGML_UNUSED(thread_row); + GGML_UNUSED(thread_col); + GGML_UNUSED(oldC); + GGML_UNUSED(param); + NO_DEVICE_CODE; +#endif +} + template= GGML_CUDA_TURING + + constexpr unsigned int SWIZZLE_MASK_1 = 0b10000; + constexpr unsigned int SWIZZLE_BITS_1 = 4; + constexpr unsigned int SWIZZLE_MASK_2 = 0b1100; + constexpr unsigned int SWIZZLE_BITS_2 = 2; + // # of threads is multiple of # of columns in the tile constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); @@ -518,6 +586,84 @@ __device__ __forceinline__ void tileMemcpyLoadB( #endif } +template +__device__ __forceinline__ void tileMemcpyAsyncLoadB( + const half *src, + half *dst, + const unsigned int curR, + const unsigned int curS, + const unsigned int block_k, + const unsigned int start_k, + const unsigned int end_k, + unsigned int thread_row, + const unsigned int thread_col, + param_t param +){ + +#if __CUDA_ARCH__ >= GGML_CUDA_AMPERE + + constexpr unsigned int SWIZZLE_MASK_1 = 0b10000; + constexpr unsigned int SWIZZLE_BITS_1 = 4; + constexpr unsigned int SWIZZLE_MASK_2 = 0b1100; + constexpr unsigned int SWIZZLE_BITS_2 = 2; + + // # of threads is multiple of # of columns in the tile + constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; + 
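+    // Each thread issues one 16-byte cp.async per iteration (a float4, i.e. 8 halves).
+    // The destination slot is permuted with the same XOR swizzle as the synchronous
+    // copy helpers above: idx ^= (idx & 0b10000) >> 4; idx ^= (idx & 0b1100) >> 2,
+    // so e.g. float4 slots 4 <-> 5 and 16 <-> 17 swap places and the later ldmatrix
+    // reads avoid shared-memory bank conflicts.
+    // When the k/end_k guard fails, cp_async_zfill passes a source size of 0, so the
+    // destination float4 is zero-filled instead of read from global memory.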
static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); + + // flatten out 2d grid of threads into in order of increasing threadIdx.x + // const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; + float4* dst_float4 = reinterpret_cast(dst); + + // assign each thread a row/column in the tile, calculate how many iterations we need + // to cover the whole tile + constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; + constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; + constexpr unsigned int ITER_DST_STEPS = ROW_STEP * TILE_COLS_VECTORIZED; + + // compile time check that we provided the right amount of registers for storage + static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); + + const unsigned int curC = start_k+block_k+thread_col*8; + const unsigned int ki = (curR*param.s+curS)*param.c + curC; + + unsigned int iter_src_idx = thread_row * param.weightKOffset + ki; + unsigned int iter_dst_idx = thread_row * TILE_COLS_VECTORIZED + thread_col; + unsigned int krow_idx = thread_row + blockIdx.x * TILE_ROWS; + const int ITER_SRC_STEPS = ROW_STEP * param.weightKOffset; + + #pragma unroll + for (unsigned int i = 0; i < NUM_ITERS; i++){ + // const unsigned int src_index = thread_row * param.weightKOffset + ki; + const unsigned int src_index = iter_src_idx; + unsigned int dst_index = iter_dst_idx; + dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); + dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); + + cp_async_zfill((void *)(&dst_float4[dst_index]), (void const *)(&src[src_index]), krow_idx < param.k && curC < end_k); + + iter_src_idx += ITER_SRC_STEPS; + krow_idx += ROW_STEP; + iter_dst_idx += ITER_DST_STEPS; + } +#else + GGML_UNUSED(src); + GGML_UNUSED(dst); + GGML_UNUSED(block_k); + GGML_UNUSED(curR); + GGML_UNUSED(curS); + GGML_UNUSED(start_k); + GGML_UNUSED(end_k); + GGML_UNUSED(thread_row); + GGML_UNUSED(thread_col); + GGML_UNUSED(param); + NO_DEVICE_CODE; +#endif +} + // same as above but without the swizzle diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index d0f67aa53b..e3968f28b8 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -716,15 +716,15 @@ int main(void) // for(int i = 0; i < ggml_nelements(wino_res); i++) { // for(int i = 0; i < 26*38; i++) { - // for(int i = 0; i < conv2d_data.size(); i++) { - // float diff = fabs(im2col_data[i] - conv2d_data[i]); - // // if(diff > 0.5) { - // printf("(%7.3f, %7.3f, %.2f, %d) \n", - // im2col_data[i], conv2d_data[i], - // diff, i); - // // break; - // // } - // } + for(int i = 0; i < conv2d_data.size(); i++) { + float diff = fabs(im2col_data[i] - conv2d_data[i]); + // if(diff > 0.5) { + printf("(%7.3f, %7.3f, %.2f, %d) \n", + im2col_data[i], conv2d_data[i], + diff, i); + // break; + // } + } ggml_free(model.ctx); ggml_backend_buffer_free(model.buffer); From e10b495dd20dfc5bd91bc5c34a86ca2aab90c993 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sat, 15 Nov 2025 01:24:09 -0500 Subject: [PATCH 098/122] add the missing guard --- ggml/src/ggml-cuda/conv2d-implicit.cu | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index d21e13d5ea..902220b74f 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -914,7 +914,9 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, int s = 0; int r = 0; while (block_k < num_block_tiles_k){ + #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm 
volatile("cp.async.wait_group %0;\n" ::"n"(0)); + #endif __syncthreads(); // moves to the next tile From e489dd277352edb6f73be5ceed0ea1e67dd11931 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sat, 15 Nov 2025 09:58:23 -0500 Subject: [PATCH 099/122] WIP --- ggml/src/ggml-cuda/conv2d-implicit.cu | 18 ++++++---- tests/test-backend-ops.cpp | 5 ++- tests/test-conv2d.cpp | 51 ++++++++++++++++++--------- 3 files changed, 49 insertions(+), 25 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 902220b74f..ec975194cc 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -913,13 +913,16 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, // for (unsigned int block_k = 1; block_k <= num_block_tiles_k; block_k++){ int s = 0; int r = 0; - while (block_k < num_block_tiles_k){ - #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + while (block_krs < num_block_tiles_krs) { + asm volatile("cp.async.wait_group %0;\n" ::"n"(0)); - #endif +#else + while (block_k < num_block_tiles_k) { +#endif __syncthreads(); - // moves to the next tile + // moves to the next channel block tile int next_idx = 0; ++s; if (s == param.s) { @@ -954,7 +957,8 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, // break; // if(thread_idx == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0){ - // printf(" s = %d, r = %d, block_k = %d, next_idx = %d , %d %d \n", s, r, block_k, next_idx, block_krs, num_block_tiles_k); + // printf(" s = %d, r = %d, block_k = %d, next_idx = %d , %d, %d, %d \n", s, r, block_k, next_idx, + // block_krs, num_block_tiles_k, num_block_tiles_krs); // } // if (block_k != num_block_tiles_k){ @@ -1044,8 +1048,8 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm volatile("cp.async.wait_group %0;\n" ::"n"(0)); __syncthreads(); - half* A_warp_tile = SA2 + A_warp_tile_offset; - half* B_warp_tile = SB2 + B_warp_tile_offset; + half* A_warp_tile = SA1 + A_warp_tile_offset; + half* B_warp_tile = SB1 + B_warp_tile_offset; ldmatrix_a(A_warp_tile, A_register_); ldmatrix_b(B_warp_tile, B_register_); // outer product between mma tiles diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 16861c71c9..7a1fe441ff 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -5826,7 +5826,7 @@ static std::vector> make_test_cases_eval() { for (uint32_t s0 : { 1, 3 }) { for (uint32_t p1 : { 2, 5 }) { - for (uint32_t Cin : { 1, 25 }) { + for (uint32_t Cin : { 1, 25, 32 }) { for (uint32_t Cout : { 1, 12 }) { for (uint32_t KH : { 1, 2, 3, 11 }) { for (uint32_t KW : { 1, 2, 3, 11 }) { @@ -5854,6 +5854,9 @@ static std::vector> make_test_cases_eval() { GGML_TYPE_F16, 1, 1, 1, 1, 1, 1, false)); test_cases.emplace_back(new test_conv_2d( { 24, 24, 128, 1 }, { 3, 3, 128, 8}, GGML_TYPE_F16, 1, 1, 1, 1, 1, 1, false)); + test_cases.emplace_back(new test_conv_2d( { 24, 24, 128, 3 }, { 3, 3, 128, 8}, + GGML_TYPE_F16, 1, 1, 1, 1, 1, 1, false)); + // sycl backend will limit task global_range < MAX_INT diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index e3968f28b8..b87a04c858 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -43,7 +43,7 @@ struct ggml_cgraph * build_graph_1(const test_model&); void load_model(test_model & model, int ic, int oc, int iw, int ih, int kw = 3, int kh = 3, bool use_gpu = false ) { // 
create data int KW = kw, KH = kh, IC = ic, OC = oc; - int IW = iw, IH = ih, N = 1; + int IW = iw, IH = ih, N = 2; // srand(time(NULL)); // printf(" input: IC = %d, OC = %d, IW = %d, IH = %d \n ", IC, OC, IW, IH); @@ -176,12 +176,19 @@ struct ggml_cgraph * build_graph_0(const test_model& model) { struct ggml_cgraph * gf = ggml_new_graph(ctx0); - int s0 = 1; - int s1 = 1; - int p0 = 1; - int p1 = 1; - int d0 = 1; - int d1 = 1; + // int s0 = 1; + // int s1 = 1; + // int p0 = 1; + // int p1 = 1; + // int d0 = 1; + // int d1 = 1; + + int s0 = 3; + int s1 = 5; + int p0 = 5; + int p1 = 5; + int d0 = 2; + int d1 = 4; // recalculate for avoid fragmentation struct ggml_tensor* conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); @@ -215,12 +222,21 @@ struct ggml_cgraph * build_graph_1(const test_model& model) { struct ggml_cgraph * gf = ggml_new_graph(ctx0); - int s0 = 1; - int s1 = 1; - int p0 = 1; - int p1 = 1; - int d0 = 1; - int d1 = 1; + // int s0 = 1; + // int s1 = 1; + // int p0 = 1; + // int p1 = 1; + // int d0 = 1; + // int d1 = 1; + + + int s0 = 3; + int s1 = 5; + int p0 = 5; + int p1 = 5; + int d0 = 2; + int d1 = 4; + // recalculate for avoid fragmentation // struct ggml_tensor* conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); @@ -301,7 +317,8 @@ static std::vector> configs = { // std::make_tuple(960,320,104,152,3,3), // std::make_tuple(1280,1280,26,38,3,3), // std::make_tuple(1920,640,32,32,3,3) - std::make_tuple(1280,1280,16,16,3,3), + // std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(32,12,141,133,3,3), // std::make_tuple(32,8,24,24,3,3), // std::make_tuple(640,640,64,64,3,3), // std::make_tuple(320,640,32,32,3,3), @@ -718,12 +735,12 @@ int main(void) // for(int i = 0; i < 26*38; i++) { for(int i = 0; i < conv2d_data.size(); i++) { float diff = fabs(im2col_data[i] - conv2d_data[i]); - // if(diff > 0.5) { + if(diff > 0.5) { printf("(%7.3f, %7.3f, %.2f, %d) \n", im2col_data[i], conv2d_data[i], diff, i); - // break; - // } + break; + } } ggml_free(model.ctx); From fa7dd684bf34fd55be165dca3fc899d525fb81f9 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sat, 15 Nov 2025 14:45:01 -0500 Subject: [PATCH 100/122] not working properly for channel numbers of 32, 48, 96 etc., ok for 64, 128... 
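One suspect is the fast NCHW2NHWC filter transpose rather than the GEMM itself, since
only some channel counts fail even though all of them are multiples of 8. A quick way to
localize it is to compare the transposed device buffer against a naive per-output-channel
reference on the host. Rough sketch (illustrative only, not part of this patch; k = output
channels, c = input channels, rs = r*s, float used instead of half for clarity):

    // hypothetical host-side reference: transpose each [c][rs] filter slice to [rs][c],
    // which is the layout the implicit-GEMM kernel reads
    static void nchw2nhwc_ref(const float * src, float * dst, int k, int c, int rs) {
        for (int m = 0; m < k; ++m) {
            for (int p = 0; p < rs; ++p) {
                for (int ci = 0; ci < c; ++ci) {
                    dst[(m * rs + p) * c + ci] = src[(m * c + ci) * rs + p];
                }
            }
        }
    }

Comparing this reference against the device buffer for c = 32, 48, 96 versus 64, 128 should
show whether the transpose or the matmul path is at fault.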
--- ggml/src/ggml-cuda/conv2d-implicit.cu | 86 +++++++++++++++++++++++-- ggml/src/ggml-cuda/conv2d-implicit.cuh | 2 +- tests/test-backend-ops.cpp | 2 +- tests/test-conv2d.cpp | 89 +++++++++++++++----------- 4 files changed, 134 insertions(+), 45 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index ec975194cc..b665ae4dfa 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -956,10 +956,10 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, // if(block_k == num_block_tiles_k) // break; - // if(thread_idx == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0){ - // printf(" s = %d, r = %d, block_k = %d, next_idx = %d , %d, %d, %d \n", s, r, block_k, next_idx, - // block_krs, num_block_tiles_k, num_block_tiles_krs); - // } + if(thread_idx == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0){ + printf(" s = %d, r = %d, block_k = %d, next_idx = %d , %d, %d, %d \n", s, r, block_k, next_idx, + block_krs, num_block_tiles_k, num_block_tiles_krs); + } // if (block_k != num_block_tiles_k){ if (block_krs != num_block_tiles_krs){ @@ -1024,8 +1024,47 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, #endif } } + + + // if(threadIdx.x >= 8 && threadIdx.x < 12 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0){ + // printf("A %d, %d, %d: %f, %f \n", block_krs, mma_k, threadIdx.x, + // __half2float(A_register_[1][mma_k][0]), + // __half2float(A_register_[1][mma_k][1])); + // } + // if(threadIdx.x < 4 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0){ + // printf("B %d, %d, %d: %f, %f\n", block_krs, mma_k, threadIdx.x, + // __half2float(B_register_[mma_k][1][0]), + // __half2float(B_register_[mma_k][1][1])); + // } + // if(threadIdx.x == 8 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0){ + // printf("C %d, %d, %d: %f, %f, %f, %f\n", block_krs, mma_k, threadIdx.x, + // __half2float(acc_register_[1][1][0]), + // __half2float(acc_register_[1][1][1]), + // __half2float(acc_register_[1][1][2]), + // __half2float(acc_register_[1][1][3])); + // } + + // if(threadIdx.x < 4 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0){ + // printf("A %d, %d, (%d, %d) %d: %f, %f \n", block_krs, mma_k, r, s, threadIdx.x, + // __half2float(A_register_[0][mma_k][0]), + // __half2float(A_register_[0][mma_k][1])); + // } + // if(threadIdx.x < 4 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0){ + // printf("B %d, %d, (%d, %d) %d: %f, %f\n", block_krs, mma_k, r, s, threadIdx.x, + // __half2float(B_register_[mma_k][0][0]), + // __half2float(B_register_[mma_k][0][1])); + // } + // if(threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0){ + // printf("C %d, %d, (%d, %d) %d: %f, %f, %f, %f\n", block_krs, mma_k, r, s, threadIdx.x, + // __half2float(acc_register_[0][0][0]), + // __half2float(acc_register_[0][0][1]), + // __half2float(acc_register_[0][0][2]), + // __half2float(acc_register_[0][0][3])); + // } + } + // if (block_k != num_block_tiles_k) if (block_krs != num_block_tiles_krs) { #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE @@ -1086,13 +1125,41 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, #endif } } + // if(threadIdx.x < 4 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0){ + // printf("A %d, %d, (%d, %d) %d: %f, %f \n", block_krs, mma_k, r, s, threadIdx.x, + // __half2float(A_register_[0][mma_k][0]), + // __half2float(A_register_[0][mma_k][1])); 
+ // } + // if(threadIdx.x < 4 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0){ + // printf("B %d, %d, (%d, %d) %d: %f, %f\n", block_krs, mma_k, r, s, threadIdx.x, + // __half2float(B_register_[mma_k][0][0]), + // __half2float(B_register_[mma_k][0][1])); + // } + // if(threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0){ + // printf("C %d, %d, (%d, %d) %d: %f, %f, %f, %f\n", block_krs, mma_k, r, s, threadIdx.x, + // __half2float(acc_register_[0][0][0]), + // __half2float(acc_register_[0][0][1]), + // __half2float(acc_register_[0][0][2]), + // __half2float(acc_register_[0][0][3])); + // } } #endif - // if(threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0){ - // printf(" %u, %f\n", blockIdx.z, __half2float(acc_register_[0][0][0])); + // if(threadIdx.x == 8 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0){ + // printf(" %u, %f, %f, %f, %f\n", blockIdx.z, + // __half2float(acc_register_[1][1][0]), + // __half2float(acc_register_[1][1][1]), + // __half2float(acc_register_[1][1][2]), + // __half2float(acc_register_[1][1][3])); // } + if(threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0){ + printf(" %u, %f, %f, %f, %f\n", blockIdx.z, + __half2float(acc_register_[0][1][0]), + __half2float(acc_register_[0][1][1]), + __half2float(acc_register_[0][1][2]), + __half2float(acc_register_[0][1][3])); + } // reuse smem half *smemoutput = shmem; @@ -1145,10 +1212,14 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, half (&res_)[2] = reinterpret_cast(dst_ptr); if (n < param.n && row < param.k && col < param.PQ) { const uint outOffset = ((ksplit > 0) ? z * param.NKPQ : 0) + n * param.KPQ + row * param.PQ + col; + // if(row == 8 && col == 18) + // printf("A %u, %u, %f \n", outOffset, z, ggml_cuda_cast(res_[0])); output[outOffset] = ggml_cuda_cast(res_[0]); } if (n < param.n && row+1 < param.k && col < param.PQ) { const uint outOffset = ((ksplit > 0) ? 
z * param.NKPQ : 0) + n * param.KPQ + (row+1) * param.PQ + col; + // if(row+1 == 8 && col == 17) + // printf("B %u, %u, %f \n", outOffset, z, ggml_cuda_cast(res_[0])); output[outOffset] = ggml_cuda_cast(res_[1]); } } @@ -1353,9 +1424,10 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa } } } - + candidate = -1; if(candidate != -1){ j = candidate; + printf("choosing %d \n", j); if (j == 2) { launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index 4728e0e757..ee56c80b7f 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -476,7 +476,7 @@ __device__ __forceinline__ unsigned int tileMemcpyAsyncLoadA( #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++){ bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS)); - // if(threadIdx.x == 3 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 1){ + // if(threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 1){ // printf(" %u, %u, %u, %u, %u, %lld, %d\n", i, curR, curS, oldC, curC, element_offset[i], valid?1:0); // } unsigned int dst_index = iter_idx; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 7a1fe441ff..171c500668 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -5826,7 +5826,7 @@ static std::vector> make_test_cases_eval() { for (uint32_t s0 : { 1, 3 }) { for (uint32_t p1 : { 2, 5 }) { - for (uint32_t Cin : { 1, 25, 32 }) { + for (uint32_t Cin : { 1, 25 }) { for (uint32_t Cout : { 1, 12 }) { for (uint32_t KH : { 1, 2, 3, 11 }) { for (uint32_t KW : { 1, 2, 3, 11 }) { diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index b87a04c858..8187d5f5fd 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -43,7 +43,7 @@ struct ggml_cgraph * build_graph_1(const test_model&); void load_model(test_model & model, int ic, int oc, int iw, int ih, int kw = 3, int kh = 3, bool use_gpu = false ) { // create data int KW = kw, KH = kh, IC = ic, OC = oc; - int IW = iw, IH = ih, N = 2; + int IW = iw, IH = ih, N = 1; // srand(time(NULL)); // printf(" input: IC = %d, OC = %d, IW = %d, IH = %d \n ", IC, OC, IW, IH); @@ -53,6 +53,8 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, int kw = 3, for (int i = 0; i < KW * KH * IC * OC; i++) { // adata[i] = 2.f; // adata[i] = (float)(i%KW)-1.f; + // adata[i] = (float)((i+1)%KW+1)/10.0; + // adata[i] = (float)(i%100); // adata[i] = (rand() % 255) / 255.0; float r = -1.f + static_cast (rand()) /( static_cast (RAND_MAX/(1.f-(-1.f)))); adata[i] = r; @@ -176,19 +178,19 @@ struct ggml_cgraph * build_graph_0(const test_model& model) { struct ggml_cgraph * gf = ggml_new_graph(ctx0); - // int s0 = 1; - // int s1 = 1; - // int p0 = 1; - // int p1 = 1; - // int d0 = 1; - // int d1 = 1; + int s0 = 1; + int s1 = 1; + int p0 = 1; + int p1 = 1; + int d0 = 1; + int d1 = 1; - int s0 = 3; - int s1 = 5; - int p0 = 5; - int p1 = 5; - int d0 = 2; - int d1 = 4; + // int s0 = 3; + // int s1 = 5; + // int p0 = 5; + // int p1 = 5; + // int d0 = 2; + // int d1 = 4; // recalculate for avoid fragmentation struct ggml_tensor* conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); @@ -222,20 +224,20 @@ struct ggml_cgraph * build_graph_1(const test_model& model) { struct ggml_cgraph * gf = ggml_new_graph(ctx0); - // int s0 = 1; - // int s1 = 
1; - // int p0 = 1; - // int p1 = 1; - // int d0 = 1; - // int d1 = 1; + int s0 = 1; + int s1 = 1; + int p0 = 1; + int p1 = 1; + int d0 = 1; + int d1 = 1; - int s0 = 3; - int s1 = 5; - int p0 = 5; - int p1 = 5; - int d0 = 2; - int d1 = 4; + // int s0 = 3; + // int s1 = 5; + // int p0 = 5; + // int p1 = 5; + // int d0 = 2; + // int d1 = 4; // recalculate for avoid fragmentation @@ -318,7 +320,21 @@ static std::vector> configs = { // std::make_tuple(1280,1280,26,38,3,3), // std::make_tuple(1920,640,32,32,3,3) // std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(32,12,141,133,3,3), + // std::make_tuple(32,12,141,133,3,3), + // std::make_tuple(32,6,141,133,3,3), + // std::make_tuple(32,12,141,121,3,3), + // std::make_tuple(32,9,141,121,3,3), + // std::make_tuple(320,8,16,16,3,3), //working + // std::make_tuple(320,9,16,16,3,3), //working + // std::make_tuple(320,12,16,16,3,3), //working + // std::make_tuple(256,12,16,16,3,3), //working + // std::make_tuple(32,12,16,16,3,3), //not working + // std::make_tuple(48,12,16,16,3,3), // not working + std::make_tuple(96,12,16,16,3,3), //not working + // std::make_tuple(64,12,16,16,3,3), //working + // std::make_tuple(64,12,141,133,3,3), //working + // std::make_tuple(32,12,141,133,3,3), //working + // std::make_tuple(1280,1280,16,16,3,3), // std::make_tuple(32,8,24,24,3,3), // std::make_tuple(640,640,64,64,3,3), // std::make_tuple(320,640,32,32,3,3), @@ -730,18 +746,19 @@ int main(void) run_time0, mem_size0/1024.0f/1024.0f, run_time1, mem_size1/1024.0f/1024.0f); - + // int i = 2048; // for(int i = 0; i < ggml_nelements(wino_res); i++) { // for(int i = 0; i < 26*38; i++) { - for(int i = 0; i < conv2d_data.size(); i++) { - float diff = fabs(im2col_data[i] - conv2d_data[i]); - if(diff > 0.5) { - printf("(%7.3f, %7.3f, %.2f, %d) \n", - im2col_data[i], conv2d_data[i], - diff, i); - break; - } - } + // for(int i = 0; i < conv2d_data.size(); i++) { + // float diff = fabs(im2col_data[i] - conv2d_data[i]); + // // if(diff > 0.5) { + // // if(diff > 2.0) { + // printf("(%7.3f, %7.3f, %.2f, %d) \n", + // im2col_data[i], conv2d_data[i], + // diff, i); + // // break; + // // } + // } ggml_free(model.ctx); ggml_backend_buffer_free(model.buffer); From 3591e83db97486d1ca836f4fd0e5be0602968e50 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sat, 15 Nov 2025 22:37:52 -0500 Subject: [PATCH 101/122] the special filter transpose NCHW2NHWC is broken, disable it and use the other less optimized one --- ggml/src/ggml-cuda/conv2d-implicit.cu | 160 +++++++++++++------------- tests/test-backend-ops.cpp | 2 +- tests/test-conv2d.cpp | 1 - 3 files changed, 81 insertions(+), 82 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index b665ae4dfa..e923d77e0b 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -82,47 +82,48 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co } } -template -static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, const int ne00, const int ne01){ + //*** broken, has bugs *** +// template +// static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, const int ne00, const int ne01){ - const int64_t nmat = ne / (ne00 * ne01); - const int64_t n = ne00 * ne01; +// const int64_t nmat = ne / (ne00 * ne01); +// const int64_t n = ne00 * ne01; - const unsigned int tx = threadIdx.x; - const unsigned int bx = blockIdx.x; - const unsigned int by = blockIdx.y; +// const unsigned int tx = 
threadIdx.x; +// const unsigned int bx = blockIdx.x; +// const unsigned int by = blockIdx.y; - __shared__ src_T tile[rs*blk_c]; +// __shared__ src_T tile[rs*blk_c]; -#pragma unroll - for(int i = 0; i < CUDA_NCHW_2_NHWC_BLOCK_NM; ++i){ +// #pragma unroll +// for(int i = 0; i < CUDA_NCHW_2_NHWC_BLOCK_NM; ++i){ - const unsigned int imat = by * CUDA_NCHW_2_NHWC_BLOCK_NM + i; - if(imat >= nmat) - break; -#pragma unroll - for (unsigned int j = 0; j < rs; j++){ - const unsigned int row = (j * blk_c + tx) % rs; - const unsigned int col = (j * blk_c + tx) / rs; - const unsigned int src_index = imat*n + bx * blk_c * rs + j * blk_c + tx; - unsigned int idx = row * blk_c + col; - idx = idx ^ ((idx & mask) >> 4); - if (src_index < ne) { - tile[idx] = src[src_index]; - } - } - __syncthreads(); -#pragma unroll - for (unsigned int j = 0; j < rs; j++){ - const unsigned int dst_index = imat*n + j*ne00 + bx*blk_c + tx; - if(dst_index < ne){ - unsigned int idx = j*blk_c + tx; - idx = idx ^ ((idx & mask) >> 4); - dst[dst_index] = ggml_cuda_cast(tile[idx]); - } - } - } -} +// const unsigned int imat = by * CUDA_NCHW_2_NHWC_BLOCK_NM + i; +// if(imat >= nmat) +// break; +// #pragma unroll +// for (unsigned int j = 0; j < rs; j++){ +// const unsigned int row = (j * blk_c + tx) % rs; +// const unsigned int col = (j * blk_c + tx) / rs; +// const unsigned int src_index = imat*n + bx * blk_c * rs + j * blk_c + tx; +// unsigned int idx = row * blk_c + col; +// idx = idx ^ ((idx & mask) >> 4); +// if (src_index < ne) { +// tile[idx] = src[src_index]; +// } +// } +// __syncthreads(); +// #pragma unroll +// for (unsigned int j = 0; j < rs; j++){ +// const unsigned int dst_index = imat*n + j*ne00 + bx*blk_c + tx; +// if(dst_index < ne){ +// unsigned int idx = j*blk_c + tx; +// idx = idx ^ ((idx & mask) >> 4); +// dst[dst_index] = ggml_cuda_cast(tile[idx]); +// } +// } +// } +// } @@ -956,10 +957,10 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, // if(block_k == num_block_tiles_k) // break; - if(thread_idx == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0){ - printf(" s = %d, r = %d, block_k = %d, next_idx = %d , %d, %d, %d \n", s, r, block_k, next_idx, - block_krs, num_block_tiles_k, num_block_tiles_krs); - } + // if(thread_idx == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0){ + // printf(" s = %d, r = %d, block_k = %d, next_idx = %d , %d, %d, %d \n", s, r, block_k, next_idx, + // block_krs, num_block_tiles_k, num_block_tiles_krs); + // } // if (block_k != num_block_tiles_k){ if (block_krs != num_block_tiles_krs){ @@ -1025,7 +1026,6 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, } } - // if(threadIdx.x >= 8 && threadIdx.x < 12 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0){ // printf("A %d, %d, %d: %f, %f \n", block_krs, mma_k, threadIdx.x, // __half2float(A_register_[1][mma_k][0]), @@ -1153,13 +1153,13 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, // __half2float(acc_register_[1][1][2]), // __half2float(acc_register_[1][1][3])); // } - if(threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0){ - printf(" %u, %f, %f, %f, %f\n", blockIdx.z, - __half2float(acc_register_[0][1][0]), - __half2float(acc_register_[0][1][1]), - __half2float(acc_register_[0][1][2]), - __half2float(acc_register_[0][1][3])); - } + // if(threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0){ + // printf(" %u, %f, %f, %f, %f\n", blockIdx.z, + // 
__half2float(acc_register_[0][1][0]), + // __half2float(acc_register_[0][1][1]), + // __half2float(acc_register_[0][1][2]), + // __half2float(acc_register_[0][1][3])); + // } // reuse smem half *smemoutput = shmem; @@ -1341,42 +1341,42 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa dim3 dimGrid1((ne00 + CUDA_NCHW_2_NHWC_BLOCK_C - 1) / CUDA_NCHW_2_NHWC_BLOCK_C, (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM, 1) ; - if (ne01 == 25) { - constexpr unsigned int mask = filter_swizzle_mask(25, CUDA_NCHW_2_NHWC_BLOCK_C); - NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); - }else if (ne01 == 16) { - constexpr unsigned int mask = filter_swizzle_mask(16, CUDA_NCHW_2_NHWC_BLOCK_C); - NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); - }else if (ne01 == 9) { - constexpr unsigned int mask = filter_swizzle_mask(9, CUDA_NCHW_2_NHWC_BLOCK_C); - NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); - } else if (ne01 == 8) { - constexpr unsigned int mask = filter_swizzle_mask(8, CUDA_NCHW_2_NHWC_BLOCK_C); - NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); - } else if (ne01 == 7) { - constexpr unsigned int mask = filter_swizzle_mask(7, CUDA_NCHW_2_NHWC_BLOCK_C); - NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); - } else if (ne01 == 6) { - constexpr unsigned int mask = filter_swizzle_mask(6, CUDA_NCHW_2_NHWC_BLOCK_C); - NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); - } else if (ne01 == 5) { - constexpr unsigned int mask = filter_swizzle_mask(5, CUDA_NCHW_2_NHWC_BLOCK_C); - NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); - } else if (ne01 == 4) { - constexpr unsigned int mask = filter_swizzle_mask(4, CUDA_NCHW_2_NHWC_BLOCK_C); - NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); - } else if (ne01 == 3) { - constexpr unsigned int mask = filter_swizzle_mask(3, CUDA_NCHW_2_NHWC_BLOCK_C); - NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); - } else if (ne01 == 2) { - constexpr unsigned int mask = filter_swizzle_mask(2, CUDA_NCHW_2_NHWC_BLOCK_C); - NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); - } else { + // if (ne01 == 25) { + // constexpr unsigned int mask = filter_swizzle_mask(25, CUDA_NCHW_2_NHWC_BLOCK_C); + // NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + // }else if (ne01 == 16) { + // constexpr unsigned int mask = filter_swizzle_mask(16, CUDA_NCHW_2_NHWC_BLOCK_C); + // NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + // }else if (ne01 == 9) { + // constexpr unsigned int mask = filter_swizzle_mask(9, CUDA_NCHW_2_NHWC_BLOCK_C); + // NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + // } else if (ne01 == 8) { + // constexpr unsigned int mask = filter_swizzle_mask(8, CUDA_NCHW_2_NHWC_BLOCK_C); + // NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + // } else if (ne01 == 7) { + // constexpr unsigned int mask = filter_swizzle_mask(7, CUDA_NCHW_2_NHWC_BLOCK_C); + // NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + // } else if (ne01 == 6) { + // constexpr unsigned int mask = filter_swizzle_mask(6, CUDA_NCHW_2_NHWC_BLOCK_C); + // NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + // } else if (ne01 == 5) { + // constexpr unsigned int mask = filter_swizzle_mask(5, CUDA_NCHW_2_NHWC_BLOCK_C); + // NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + // } else if (ne01 == 4) { + // constexpr unsigned int mask = filter_swizzle_mask(4, CUDA_NCHW_2_NHWC_BLOCK_C); + // NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + // } else if (ne01 == 3) { + // constexpr 
unsigned int mask = filter_swizzle_mask(3, CUDA_NCHW_2_NHWC_BLOCK_C); + // NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + // } else if (ne01 == 2) { + // constexpr unsigned int mask = filter_swizzle_mask(2, CUDA_NCHW_2_NHWC_BLOCK_C); + // NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + // } else { dim3 dimGrid2((ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ; NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); - } + // } } const half *X_H = input_f16.get(); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 171c500668..b037af506b 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -5826,7 +5826,7 @@ static std::vector> make_test_cases_eval() { for (uint32_t s0 : { 1, 3 }) { for (uint32_t p1 : { 2, 5 }) { - for (uint32_t Cin : { 1, 25 }) { + for (uint32_t Cin : { 1, 25, 48 }) { for (uint32_t Cout : { 1, 12 }) { for (uint32_t KH : { 1, 2, 3, 11 }) { for (uint32_t KW : { 1, 2, 3, 11 }) { diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index 8187d5f5fd..b7ff38b833 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -752,7 +752,6 @@ int main(void) // for(int i = 0; i < conv2d_data.size(); i++) { // float diff = fabs(im2col_data[i] - conv2d_data[i]); // // if(diff > 0.5) { - // // if(diff > 2.0) { // printf("(%7.3f, %7.3f, %.2f, %d) \n", // im2col_data[i], conv2d_data[i], // diff, i); From 721fa41076e9a387d4ccd77387b6ea63cca7c228 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sat, 15 Nov 2025 23:59:38 -0500 Subject: [PATCH 102/122] restore split-k for small inputs --- ggml/src/ggml-cuda/conv2d-implicit.cu | 81 +++++++++++++-------------- tests/test-conv2d.cpp | 4 +- 2 files changed, 42 insertions(+), 43 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index e923d77e0b..15cb613dc6 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -83,47 +83,48 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co } //*** broken, has bugs *** -// template -// static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, const int ne00, const int ne01){ +template +static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, const int ne00, const int ne01){ -// const int64_t nmat = ne / (ne00 * ne01); -// const int64_t n = ne00 * ne01; + const int64_t nmat = ne / (ne00 * ne01); + const int64_t n = ne00 * ne01; -// const unsigned int tx = threadIdx.x; -// const unsigned int bx = blockIdx.x; -// const unsigned int by = blockIdx.y; + const unsigned int tx = threadIdx.x; + const unsigned int bx = blockIdx.x; + const unsigned int by = blockIdx.y; -// __shared__ src_T tile[rs*blk_c]; + __shared__ src_T tile[rs*blk_c]; -// #pragma unroll -// for(int i = 0; i < CUDA_NCHW_2_NHWC_BLOCK_NM; ++i){ +#pragma unroll + for(int i = 0; i < CUDA_NCHW_2_NHWC_BLOCK_NM; ++i){ -// const unsigned int imat = by * CUDA_NCHW_2_NHWC_BLOCK_NM + i; -// if(imat >= nmat) -// break; -// #pragma unroll -// for (unsigned int j = 0; j < rs; j++){ -// const unsigned int row = (j * blk_c + tx) % rs; -// const unsigned int col = (j * blk_c + tx) / rs; -// const unsigned int src_index = imat*n + bx * blk_c * rs + j * blk_c + tx; -// unsigned int idx = row * blk_c + col; -// idx = idx ^ ((idx & mask) >> 4); -// if (src_index < ne) { -// 
tile[idx] = src[src_index]; -// } -// } -// __syncthreads(); -// #pragma unroll -// for (unsigned int j = 0; j < rs; j++){ -// const unsigned int dst_index = imat*n + j*ne00 + bx*blk_c + tx; -// if(dst_index < ne){ -// unsigned int idx = j*blk_c + tx; -// idx = idx ^ ((idx & mask) >> 4); -// dst[dst_index] = ggml_cuda_cast(tile[idx]); -// } -// } -// } -// } + const unsigned int imat = by * CUDA_NCHW_2_NHWC_BLOCK_NM + i; + if(imat >= nmat) + break; +#pragma unroll + for (unsigned int j = 0; j < rs; j++){ + const unsigned int row = (j * blk_c + tx) % rs; + const unsigned int col = (j * blk_c + tx) / rs; + const unsigned int src_index = imat*n + bx * blk_c * rs + j * blk_c + tx; + // const unsigned int src_index = imat*n + rs*ne00 + bx * blk_c + j * blk_c + tx; + unsigned int idx = row * blk_c + col; + // idx = idx ^ ((idx & mask) >> 4); + if (src_index < ne) { + tile[idx] = src[src_index]; + } + } + __syncthreads(); +#pragma unroll + for (unsigned int j = 0; j < rs; j++){ + const unsigned int dst_index = imat*n + j*ne00 + bx*blk_c + tx; + if(dst_index < ne){ + unsigned int idx = j*blk_c + tx; + // idx = idx ^ ((idx & mask) >> 4); + dst[dst_index] = ggml_cuda_cast(tile[idx]); + } + } + } +} @@ -1338,9 +1339,9 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa ggml_cuda_pool_alloc kernel_f16(ctx.pool(id)); if (ne01 > 1){ kernel_f16.alloc(ne); - dim3 dimGrid1((ne00 + CUDA_NCHW_2_NHWC_BLOCK_C - 1) / CUDA_NCHW_2_NHWC_BLOCK_C, - (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM, - 1) ; + // dim3 dimGrid1((ne00 + CUDA_NCHW_2_NHWC_BLOCK_C - 1) / CUDA_NCHW_2_NHWC_BLOCK_C, + // (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM, + // 1) ; // if (ne01 == 25) { // constexpr unsigned int mask = filter_swizzle_mask(25, CUDA_NCHW_2_NHWC_BLOCK_C); // NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); @@ -1424,10 +1425,8 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa } } } - candidate = -1; if(candidate != -1){ j = candidate; - printf("choosing %d \n", j); if (j == 2) { launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index b7ff38b833..2ebf0e2c14 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -319,7 +319,7 @@ static std::vector> configs = { // std::make_tuple(960,320,104,152,3,3), // std::make_tuple(1280,1280,26,38,3,3), // std::make_tuple(1920,640,32,32,3,3) - // std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), // std::make_tuple(32,12,141,133,3,3), // std::make_tuple(32,6,141,133,3,3), // std::make_tuple(32,12,141,121,3,3), @@ -330,7 +330,7 @@ static std::vector> configs = { // std::make_tuple(256,12,16,16,3,3), //working // std::make_tuple(32,12,16,16,3,3), //not working // std::make_tuple(48,12,16,16,3,3), // not working - std::make_tuple(96,12,16,16,3,3), //not working + // std::make_tuple(96,12,16,16,3,3), //not working // std::make_tuple(64,12,16,16,3,3), //working // std::make_tuple(64,12,141,133,3,3), //working // std::make_tuple(32,12,141,133,3,3), //working From bccd869968ea67f0136ab7294351ce8938cb2101 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sun, 16 Nov 2025 09:31:38 -0500 Subject: [PATCH 103/122] fixed a bug in the special filter transpose NCHW2NHWC; still failing for channel number < 32 --- ggml/src/ggml-cuda/conv2d-implicit.cu | 20 +++++++++++--------- tests/test-backend-ops.cpp | 2 +- tests/test-conv2d.cpp | 23 
++++++++++++----------- 3 files changed, 24 insertions(+), 21 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 15cb613dc6..2bd8bcd281 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -93,6 +93,8 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co const unsigned int bx = blockIdx.x; const unsigned int by = blockIdx.y; + const unsigned int blk = (bx+1) * blk_c <= ne00 ? blk_c : ne00 - bx * blk_c; + __shared__ src_T tile[rs*blk_c]; #pragma unroll @@ -103,12 +105,12 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co break; #pragma unroll for (unsigned int j = 0; j < rs; j++){ - const unsigned int row = (j * blk_c + tx) % rs; - const unsigned int col = (j * blk_c + tx) / rs; - const unsigned int src_index = imat*n + bx * blk_c * rs + j * blk_c + tx; + const unsigned int row = (j * blk + tx) % rs; + const unsigned int col = (j * blk + tx) / rs; + const unsigned int src_index = imat*n + bx * blk_c * rs + j * blk + tx; // const unsigned int src_index = imat*n + rs*ne00 + bx * blk_c + j * blk_c + tx; unsigned int idx = row * blk_c + col; - // idx = idx ^ ((idx & mask) >> 4); + idx = idx ^ ((idx & mask) >> 4); if (src_index < ne) { tile[idx] = src[src_index]; } @@ -117,9 +119,9 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co #pragma unroll for (unsigned int j = 0; j < rs; j++){ const unsigned int dst_index = imat*n + j*ne00 + bx*blk_c + tx; - if(dst_index < ne){ + if(dst_index < ne && tx < blk){ unsigned int idx = j*blk_c + tx; - // idx = idx ^ ((idx & mask) >> 4); + idx = idx ^ ((idx & mask) >> 4); dst[dst_index] = ggml_cuda_cast(tile[idx]); } } @@ -1338,17 +1340,17 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa // ggml_cuda_pool_alloc kernel_f16(ctx.pool(id), ne); ggml_cuda_pool_alloc kernel_f16(ctx.pool(id)); if (ne01 > 1){ - kernel_f16.alloc(ne); + // kernel_f16.alloc(ne); // dim3 dimGrid1((ne00 + CUDA_NCHW_2_NHWC_BLOCK_C - 1) / CUDA_NCHW_2_NHWC_BLOCK_C, // (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM, // 1) ; // if (ne01 == 25) { // constexpr unsigned int mask = filter_swizzle_mask(25, CUDA_NCHW_2_NHWC_BLOCK_C); // NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); - // }else if (ne01 == 16) { + // } else if (ne01 == 16) { // constexpr unsigned int mask = filter_swizzle_mask(16, CUDA_NCHW_2_NHWC_BLOCK_C); // NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); - // }else if (ne01 == 9) { + // } else if (ne01 == 9) { // constexpr unsigned int mask = filter_swizzle_mask(9, CUDA_NCHW_2_NHWC_BLOCK_C); // NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); // } else if (ne01 == 8) { diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index b037af506b..3f71cf1cf4 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -5826,7 +5826,7 @@ static std::vector> make_test_cases_eval() { for (uint32_t s0 : { 1, 3 }) { for (uint32_t p1 : { 2, 5 }) { - for (uint32_t Cin : { 1, 25, 48 }) { + for (uint32_t Cin : { 1, 16, 25, 48 }) { for (uint32_t Cout : { 1, 12 }) { for (uint32_t KH : { 1, 2, 3, 11 }) { for (uint32_t KW : { 1, 2, 3, 11 }) { diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index 2ebf0e2c14..3005e17edc 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -319,7 +319,7 @@ static std::vector> configs = { // std::make_tuple(960,320,104,152,3,3), // 
std::make_tuple(1280,1280,26,38,3,3), // std::make_tuple(1920,640,32,32,3,3) - std::make_tuple(1280,1280,16,16,3,3), + // std::make_tuple(1280,1280,16,16,3,3), // std::make_tuple(32,12,141,133,3,3), // std::make_tuple(32,6,141,133,3,3), // std::make_tuple(32,12,141,121,3,3), @@ -328,7 +328,8 @@ static std::vector> configs = { // std::make_tuple(320,9,16,16,3,3), //working // std::make_tuple(320,12,16,16,3,3), //working // std::make_tuple(256,12,16,16,3,3), //working - // std::make_tuple(32,12,16,16,3,3), //not working + // std::make_tuple(32,12,16,16,3,3), //not working + std::make_tuple(16,12,16,16,3,3), //not working // std::make_tuple(48,12,16,16,3,3), // not working // std::make_tuple(96,12,16,16,3,3), //not working // std::make_tuple(64,12,16,16,3,3), //working @@ -749,15 +750,15 @@ int main(void) // int i = 2048; // for(int i = 0; i < ggml_nelements(wino_res); i++) { // for(int i = 0; i < 26*38; i++) { - // for(int i = 0; i < conv2d_data.size(); i++) { - // float diff = fabs(im2col_data[i] - conv2d_data[i]); - // // if(diff > 0.5) { - // printf("(%7.3f, %7.3f, %.2f, %d) \n", - // im2col_data[i], conv2d_data[i], - // diff, i); - // // break; - // // } - // } + for(int i = 0; i < conv2d_data.size(); i++) { + float diff = fabs(im2col_data[i] - conv2d_data[i]); + // if(diff > 0.5) { + printf("(%7.3f, %7.3f, %.2f, %d) \n", + im2col_data[i], conv2d_data[i], + diff, i); + // break; + // } + } ggml_free(model.ctx); ggml_backend_buffer_free(model.buffer); From febee580c89d2b2c0248095dd4702a19d404906f Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sun, 16 Nov 2025 11:14:27 -0500 Subject: [PATCH 104/122] fixed anotehr bug in the special filter transpose NCHW2NHWC --- ggml/src/ggml-cuda/conv2d-implicit.cu | 79 +++++++++++++-------------- tests/test-conv2d.cpp | 23 ++++---- 2 files changed, 51 insertions(+), 51 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 2bd8bcd281..5958c3f29e 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -82,7 +82,6 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co } } - //*** broken, has bugs *** template static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, const int ne00, const int ne01){ @@ -108,10 +107,9 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co const unsigned int row = (j * blk + tx) % rs; const unsigned int col = (j * blk + tx) / rs; const unsigned int src_index = imat*n + bx * blk_c * rs + j * blk + tx; - // const unsigned int src_index = imat*n + rs*ne00 + bx * blk_c + j * blk_c + tx; unsigned int idx = row * blk_c + col; idx = idx ^ ((idx & mask) >> 4); - if (src_index < ne) { + if (src_index < ne && tx < blk) { tile[idx] = src[src_index]; } } @@ -122,7 +120,7 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co if(dst_index < ne && tx < blk){ unsigned int idx = j*blk_c + tx; idx = idx ^ ((idx & mask) >> 4); - dst[dst_index] = ggml_cuda_cast(tile[idx]); + dst[dst_index] = ggml_cuda_cast(tile[idx]); } } } @@ -1340,46 +1338,47 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa // ggml_cuda_pool_alloc kernel_f16(ctx.pool(id), ne); ggml_cuda_pool_alloc kernel_f16(ctx.pool(id)); if (ne01 > 1){ - // kernel_f16.alloc(ne); - // dim3 dimGrid1((ne00 + CUDA_NCHW_2_NHWC_BLOCK_C - 1) / CUDA_NCHW_2_NHWC_BLOCK_C, - // (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM, - // 1) ; - // 
if (ne01 == 25) { - // constexpr unsigned int mask = filter_swizzle_mask(25, CUDA_NCHW_2_NHWC_BLOCK_C); - // NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); - // } else if (ne01 == 16) { - // constexpr unsigned int mask = filter_swizzle_mask(16, CUDA_NCHW_2_NHWC_BLOCK_C); - // NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); - // } else if (ne01 == 9) { - // constexpr unsigned int mask = filter_swizzle_mask(9, CUDA_NCHW_2_NHWC_BLOCK_C); - // NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); - // } else if (ne01 == 8) { - // constexpr unsigned int mask = filter_swizzle_mask(8, CUDA_NCHW_2_NHWC_BLOCK_C); - // NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); - // } else if (ne01 == 7) { - // constexpr unsigned int mask = filter_swizzle_mask(7, CUDA_NCHW_2_NHWC_BLOCK_C); - // NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); - // } else if (ne01 == 6) { - // constexpr unsigned int mask = filter_swizzle_mask(6, CUDA_NCHW_2_NHWC_BLOCK_C); - // NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); - // } else if (ne01 == 5) { - // constexpr unsigned int mask = filter_swizzle_mask(5, CUDA_NCHW_2_NHWC_BLOCK_C); - // NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); - // } else if (ne01 == 4) { - // constexpr unsigned int mask = filter_swizzle_mask(4, CUDA_NCHW_2_NHWC_BLOCK_C); - // NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); - // } else if (ne01 == 3) { - // constexpr unsigned int mask = filter_swizzle_mask(3, CUDA_NCHW_2_NHWC_BLOCK_C); - // NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); - // } else if (ne01 == 2) { - // constexpr unsigned int mask = filter_swizzle_mask(2, CUDA_NCHW_2_NHWC_BLOCK_C); - // NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); - // } else { + kernel_f16.alloc(ne); + + dim3 dimGrid1((ne00 + CUDA_NCHW_2_NHWC_BLOCK_C - 1) / CUDA_NCHW_2_NHWC_BLOCK_C, + (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM, + 1) ; + if (ne01 == 25) { + constexpr unsigned int mask = filter_swizzle_mask(25, CUDA_NCHW_2_NHWC_BLOCK_C); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + } else if (ne01 == 16) { + constexpr unsigned int mask = filter_swizzle_mask(16, CUDA_NCHW_2_NHWC_BLOCK_C); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + } else if (ne01 == 9) { + constexpr unsigned int mask = filter_swizzle_mask(9, CUDA_NCHW_2_NHWC_BLOCK_C); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + } else if (ne01 == 8) { + constexpr unsigned int mask = filter_swizzle_mask(8, CUDA_NCHW_2_NHWC_BLOCK_C); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + } else if (ne01 == 7) { + constexpr unsigned int mask = filter_swizzle_mask(7, CUDA_NCHW_2_NHWC_BLOCK_C); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + } else if (ne01 == 6) { + constexpr unsigned int mask = filter_swizzle_mask(6, CUDA_NCHW_2_NHWC_BLOCK_C); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + } else if (ne01 == 5) { + constexpr unsigned int mask = filter_swizzle_mask(5, CUDA_NCHW_2_NHWC_BLOCK_C); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + } else if (ne01 == 4) { + constexpr unsigned int mask = filter_swizzle_mask(4, CUDA_NCHW_2_NHWC_BLOCK_C); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + } else if (ne01 == 3) { + constexpr unsigned int mask = filter_swizzle_mask(3, CUDA_NCHW_2_NHWC_BLOCK_C); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + } else if (ne01 == 2) { + constexpr unsigned int mask = filter_swizzle_mask(2, CUDA_NCHW_2_NHWC_BLOCK_C); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, 
ne00, ne01); + } else { dim3 dimGrid2((ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ; NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); - // } + } } const half *X_H = input_f16.get(); diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index 3005e17edc..11cf757bc4 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -319,7 +319,7 @@ static std::vector> configs = { // std::make_tuple(960,320,104,152,3,3), // std::make_tuple(1280,1280,26,38,3,3), // std::make_tuple(1920,640,32,32,3,3) - // std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(1280,1280,16,16,3,3), // std::make_tuple(32,12,141,133,3,3), // std::make_tuple(32,6,141,133,3,3), // std::make_tuple(32,12,141,121,3,3), @@ -329,7 +329,8 @@ static std::vector> configs = { // std::make_tuple(320,12,16,16,3,3), //working // std::make_tuple(256,12,16,16,3,3), //working // std::make_tuple(32,12,16,16,3,3), //not working - std::make_tuple(16,12,16,16,3,3), //not working + // std::make_tuple(16,12,16,16,3,3), //not working + // std::make_tuple(32,12,16,16,3,3), //not working // std::make_tuple(48,12,16,16,3,3), // not working // std::make_tuple(96,12,16,16,3,3), //not working // std::make_tuple(64,12,16,16,3,3), //working @@ -750,15 +751,15 @@ int main(void) // int i = 2048; // for(int i = 0; i < ggml_nelements(wino_res); i++) { // for(int i = 0; i < 26*38; i++) { - for(int i = 0; i < conv2d_data.size(); i++) { - float diff = fabs(im2col_data[i] - conv2d_data[i]); - // if(diff > 0.5) { - printf("(%7.3f, %7.3f, %.2f, %d) \n", - im2col_data[i], conv2d_data[i], - diff, i); - // break; - // } - } + // for(int i = 0; i < conv2d_data.size(); i++) { + // float diff = fabs(im2col_data[i] - conv2d_data[i]); + // // if(diff > 0.5) { + // printf("(%7.3f, %7.3f, %.2f, %d) \n", + // im2col_data[i], conv2d_data[i], + // diff, i); + // // break; + // // } + // } ggml_free(model.ctx); ggml_backend_buffer_free(model.buffer); From f2187bbfa2c62457f42998eca676898c4f93586a Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sun, 16 Nov 2025 12:05:23 -0500 Subject: [PATCH 105/122] added a few edge test cases --- tests/test-backend-ops.cpp | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 3f71cf1cf4..fb0224d480 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -5826,7 +5826,7 @@ static std::vector> make_test_cases_eval() { for (uint32_t s0 : { 1, 3 }) { for (uint32_t p1 : { 2, 5 }) { - for (uint32_t Cin : { 1, 16, 25, 48 }) { + for (uint32_t Cin : { 1, 25 }) { for (uint32_t Cout : { 1, 12 }) { for (uint32_t KH : { 1, 2, 3, 11 }) { for (uint32_t KW : { 1, 2, 3, 11 }) { @@ -5848,6 +5848,14 @@ static std::vector> make_test_cases_eval() { } } + test_cases.emplace_back(new test_conv_2d( { 16, 16, 8, 1}, { 3, 3, 8, 12}, + GGML_TYPE_F16, 1, 1, 1, 1, 1, 1, false)); + test_cases.emplace_back(new test_conv_2d( { 16, 16, 16, 1}, { 3, 3, 16, 6}, + GGML_TYPE_F16, 1, 1, 1, 1, 1, 1, false)); + test_cases.emplace_back(new test_conv_2d( { 16, 16, 24, 1}, { 3, 3, 24, 6}, + GGML_TYPE_F16, 1, 1, 1, 1, 1, 1, false)); + test_cases.emplace_back(new test_conv_2d( { 16, 16, 8, 3}, { 3, 3, 8, 6}, + GGML_TYPE_F16, 1, 1, 1, 1, 1, 1, false)); test_cases.emplace_back(new test_conv_2d( { 24, 24, 32, 1 }, { 3, 3, 32, 8}, GGML_TYPE_F16, 1, 1, 1, 1, 1, 1, false)); test_cases.emplace_back(new 
test_conv_2d( { 24, 24, 96, 1 }, { 3, 3, 96, 8}, From f54cd74ed00a6095e6981d2ee6b4e801da6626a2 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sun, 16 Nov 2025 13:08:01 -0500 Subject: [PATCH 106/122] due to cp.async, only support filter size <= 32 --- ggml/src/ggml-cuda/conv2d-implicit.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 5958c3f29e..ce6c2b69d8 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -1309,7 +1309,7 @@ static void launch_conv2d_implicit_split_kernel(ggml_backend_cuda_context & ctx, static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const float * X_D, const half * K_D, float * Y_D, int cc, param_t P, cudaStream_t st) { // if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0 && (P.r > 1 || P.s > 1)) { - if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0) { + if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0 && (P.r <= 32 && P.s <= 32)) { int id = ggml_cuda_get_device(); From 775e48abb21db533284e0880c0342f293901bd41 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Mon, 17 Nov 2025 10:02:28 -0500 Subject: [PATCH 107/122] remove some repeated index computation; various code/comments clean up --- ggml/src/ggml-cuda/conv2d-implicit.cu | 187 ++++---------------- ggml/src/ggml-cuda/conv2d-implicit.cuh | 225 +++++-------------------- 2 files changed, 72 insertions(+), 340 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index ce6c2b69d8..1ffce7a9d7 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -786,14 +786,6 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, constexpr unsigned int MMA_M = 16; constexpr unsigned int MMA_N = 8; - // const unsigned int K = param.c; - // const uint inChannelOffset = param.c * param.w; - // const uint weightKOffset = param.c * param.r * param.s; - - // const unsigned int PQ = param.Ow * param.Oh; - // const unsigned int KPQ = param.k * PQ; - // const unsigned int NKPQ = param.n * KPQ; - // loop bounds, constexpr where possible allows for loop unrolling #if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN constexpr unsigned int mma_tiles_per_warp_k = 2; @@ -817,6 +809,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, unsigned int masks_a[A_K_STRID][2]; int64_t element_offset_a[A_K_STRID]; + int64_t element_offset_b; // calculate block/warp indices const unsigned int block_m = blockIdx.y; @@ -833,7 +826,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, half* B_block_smem = &shmem[BM * BK]; constexpr int BUFFER_SIZE = BM * BK + BK * BN; -#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#ifdef CP_ASYNC_AVAILABLE half* SA1 = A_block_smem; half* SB1 = B_block_smem; half* SA2 = &shmem[BUFFER_SIZE]; @@ -841,6 +834,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, #else float4 A_gmem_cache_reg[4]; float4 B_gmem_cache_reg[4]; + int offset_direction = 1; #endif // declare register storage // ptx instructions expect uint32_t registers, where each uint32_t is 2 halfs packed together @@ -883,21 +877,6 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, prepareIteratorA(thread_row, masks_a, element_offset_a, param); - // for(int kk =0; kk < A_K_STRID; kk++){ - // if(element_offset_a[kk] >= 
327680) - // printf("%d, %d, %d, %d, %d, %lld \n", - // threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y, blockIdx.z, - // element_offset_a[kk]); - // } - - // if(threadIdx.x == 64 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0){ - // printf("A["); - // for(int kk =0; kk < A_K_STRID; kk++) - // printf("%f,", element_offset_a[kk]); - // printf("]\n"); - // } - - // prefetch the first block tile of A,B into shared memory const half* A_block_gmem = input; @@ -905,17 +884,19 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, unsigned int curC = tileMemcpySwizzleA(A_block_gmem, A_block_smem, 0, 0, masks_a, element_offset_a, thread_row, thread_col, start_k, end_k, param); - tileMemcpySwizzleB(B_block_gmem, B_block_smem, 0, 0, start_k, end_k, thread_row, thread_col, param); -#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + element_offset_b = curC; + tileMemcpySwizzleB(B_block_gmem, B_block_smem, 0, 0, curC, element_offset_b, start_k, end_k, thread_row, thread_col, param); + +#ifdef CP_ASYNC_AVAILABLE asm volatile("cp.async.commit_group;\n" ::); #endif - int offset_direction = 1; + unsigned int block_k = 0; unsigned int block_krs = 1; - // for (unsigned int block_k = 1; block_k <= num_block_tiles_k; block_k++){ int s = 0; int r = 0; -#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + +#ifdef CP_ASYNC_AVAILABLE while (block_krs < num_block_tiles_krs) { asm volatile("cp.async.wait_group %0;\n" ::"n"(0)); @@ -944,44 +925,26 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, ++block_k; } - // if(threadIdx.x == 64 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0){ - // printf("B %d,%d,%d [", s, r, block_k); - // for(int kk =0; kk < A_K_STRID; kk++){ - // if(element_offset_a[kk] >= 327680) - // printf("%d, %d, %d, %d, %d, %lld, %d, %d, %d %d, %lld\n", - // threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y, blockIdx.z, - // element_offset_a[kk], r, s, block_k, next_idx, param.inc_next[next_idx]); - // } - // threadIdx.x == 64 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0){ - // printf("%f,", element_offset_a[kk]); - // printf("]\n"); - // if(block_k == num_block_tiles_k) - // break; - - // if(thread_idx == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0){ - // printf(" s = %d, r = %d, block_k = %d, next_idx = %d , %d, %d, %d \n", s, r, block_k, next_idx, - // block_krs, num_block_tiles_k, num_block_tiles_krs); - // } - - // if (block_k != num_block_tiles_k){ - if (block_krs != num_block_tiles_krs){ -#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + if (block_krs != num_block_tiles_krs) { +#ifdef CP_ASYNC_AVAILABLE curC = tileMemcpyAsyncLoadA(A_block_gmem, SA2, r, s, masks_a, element_offset_a, thread_row, thread_col, block_k * BK, start_k, end_k, curC, param); - tileMemcpyAsyncLoadB(B_block_gmem, SB2, r, s, block_k * BK, + element_offset_b = (r*param.s+s)*param.c + curC; + tileMemcpyAsyncLoadB(B_block_gmem, SB2, r, s, curC, element_offset_b, block_k * BK, start_k, end_k, thread_row, thread_col, param); asm volatile("cp.async.commit_group;\n" ::); #else curC = tileMemcpyLoadA(A_block_gmem, A_gmem_cache_reg, r, s, masks_a, element_offset_a, thread_row, thread_col, block_k * BK, start_k, end_k, curC, param); - tileMemcpyLoadB(B_block_gmem, B_gmem_cache_reg, r, s, block_k * BK, + element_offset_b = (r*param.s+s)*param.c + curC; + tileMemcpyLoadB(B_block_gmem, B_gmem_cache_reg, r, s, curC, element_offset_b, block_k * BK, start_k, end_k, thread_row, thread_col, param); 
#endif } -#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#ifdef CP_ASYNC_AVAILABLE half* A_warp_tile = SA1 + A_warp_tile_offset; half* B_warp_tile = SB1 + B_warp_tile_offset; #else @@ -994,11 +957,11 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, // outer product between mma tiles #pragma unroll - for (unsigned int mma_k = 0; mma_k < mma_tiles_per_warp_k; mma_k++){ + for (unsigned int mma_k = 0; mma_k < mma_tiles_per_warp_k; mma_k++) { #pragma unroll - for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++){ + for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++) { #pragma unroll - for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++){ + for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++) { #if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN asm volatile ( "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " @@ -1026,49 +989,11 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, #endif } } - - // if(threadIdx.x >= 8 && threadIdx.x < 12 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0){ - // printf("A %d, %d, %d: %f, %f \n", block_krs, mma_k, threadIdx.x, - // __half2float(A_register_[1][mma_k][0]), - // __half2float(A_register_[1][mma_k][1])); - // } - // if(threadIdx.x < 4 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0){ - // printf("B %d, %d, %d: %f, %f\n", block_krs, mma_k, threadIdx.x, - // __half2float(B_register_[mma_k][1][0]), - // __half2float(B_register_[mma_k][1][1])); - // } - // if(threadIdx.x == 8 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0){ - // printf("C %d, %d, %d: %f, %f, %f, %f\n", block_krs, mma_k, threadIdx.x, - // __half2float(acc_register_[1][1][0]), - // __half2float(acc_register_[1][1][1]), - // __half2float(acc_register_[1][1][2]), - // __half2float(acc_register_[1][1][3])); - // } - - // if(threadIdx.x < 4 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0){ - // printf("A %d, %d, (%d, %d) %d: %f, %f \n", block_krs, mma_k, r, s, threadIdx.x, - // __half2float(A_register_[0][mma_k][0]), - // __half2float(A_register_[0][mma_k][1])); - // } - // if(threadIdx.x < 4 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0){ - // printf("B %d, %d, (%d, %d) %d: %f, %f\n", block_krs, mma_k, r, s, threadIdx.x, - // __half2float(B_register_[mma_k][0][0]), - // __half2float(B_register_[mma_k][0][1])); - // } - // if(threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0){ - // printf("C %d, %d, (%d, %d) %d: %f, %f, %f, %f\n", block_krs, mma_k, r, s, threadIdx.x, - // __half2float(acc_register_[0][0][0]), - // __half2float(acc_register_[0][0][1]), - // __half2float(acc_register_[0][0][2]), - // __half2float(acc_register_[0][0][3])); - // } - } - // if (block_k != num_block_tiles_k) if (block_krs != num_block_tiles_krs) { -#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#ifdef CP_ASYNC_AVAILABLE half *tmp = SA1; SA1 = SA2; SA2 = tmp; tmp = SB1; SB1 = SB2; SB2 = tmp; #else @@ -1085,7 +1010,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, } -#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#ifdef CP_ASYNC_AVAILABLE asm volatile("cp.async.wait_group %0;\n" ::"n"(0)); __syncthreads(); half* A_warp_tile = SA1 + A_warp_tile_offset; @@ -1094,11 +1019,11 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, ldmatrix_b(B_warp_tile, B_register_); // outer product between mma tiles #pragma unroll - for (unsigned int mma_k = 0; mma_k < mma_tiles_per_warp_k; mma_k++){ + for 
(unsigned int mma_k = 0; mma_k < mma_tiles_per_warp_k; mma_k++) { #pragma unroll - for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++){ + for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++) { #pragma unroll - for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++){ + for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++) { #if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN asm volatile ( "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " @@ -1126,42 +1051,10 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, #endif } } - // if(threadIdx.x < 4 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0){ - // printf("A %d, %d, (%d, %d) %d: %f, %f \n", block_krs, mma_k, r, s, threadIdx.x, - // __half2float(A_register_[0][mma_k][0]), - // __half2float(A_register_[0][mma_k][1])); - // } - // if(threadIdx.x < 4 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0){ - // printf("B %d, %d, (%d, %d) %d: %f, %f\n", block_krs, mma_k, r, s, threadIdx.x, - // __half2float(B_register_[mma_k][0][0]), - // __half2float(B_register_[mma_k][0][1])); - // } - // if(threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0){ - // printf("C %d, %d, (%d, %d) %d: %f, %f, %f, %f\n", block_krs, mma_k, r, s, threadIdx.x, - // __half2float(acc_register_[0][0][0]), - // __half2float(acc_register_[0][0][1]), - // __half2float(acc_register_[0][0][2]), - // __half2float(acc_register_[0][0][3])); - // } } #endif - // if(threadIdx.x == 8 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0){ - // printf(" %u, %f, %f, %f, %f\n", blockIdx.z, - // __half2float(acc_register_[1][1][0]), - // __half2float(acc_register_[1][1][1]), - // __half2float(acc_register_[1][1][2]), - // __half2float(acc_register_[1][1][3])); - // } - // if(threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0){ - // printf(" %u, %f, %f, %f, %f\n", blockIdx.z, - // __half2float(acc_register_[0][1][0]), - // __half2float(acc_register_[0][1][1]), - // __half2float(acc_register_[0][1][2]), - // __half2float(acc_register_[0][1][3])); - // } - // reuse smem half *smemoutput = shmem; const uint lane_id = threadIdx.x % WARPSIZE; @@ -1174,16 +1067,13 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, const uint n_idx = block_m * BM + warp_m * WM + lane_id; #pragma unroll - for (int i = 0; i < 2; ++i) - { + for (int i = 0; i < 2; ++i) { const unsigned int i_offset = i * mma_tiles_per_warp_n/2; __syncthreads(); #pragma unroll - for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++) - { + for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++) { const unsigned int mma_m_offset = output_sts_addr + mma_m * MMA_M * BN / 2; - for (unsigned int mma_n = i_offset; mma_n < (i+1)*mma_tiles_per_warp_n/2; mma_n++) - { + for (unsigned int mma_n = i_offset; mma_n < (i+1)*mma_tiles_per_warp_n/2; mma_n++) { uint32_t (®_)[2] = reinterpret_cast(acc_register_[mma_m][mma_n]); uint idx = mma_m_offset + (mma_n - i_offset) * MMA_N; idx = idx ^ ((idx & 0b110000000000) >> 9); @@ -1199,13 +1089,13 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, const unsigned int m_i_wn = m_idx + i * WN / 2; #pragma unroll - for (int subk = 0; subk < WN / 4; ++subk){ + for (int subk = 0; subk < WN / 4; ++subk) { const uint row = m_i_wn + subk*2; uint idx = output_lds_addr + subk*2; idx = idx ^ ((idx & 0b110000000000) >> 9); idx = idx ^ ((idx & 0b1110000000) >> 4); #pragma unroll - for (int j = 0; 
j < 4; ++j){ + for (int j = 0; j < 4; ++j) { const uint gemm_i = n_idx + j*32; const int n = fastdiv(gemm_i, param.OHOW_fastdiv); const int col = fastmodulo(gemm_i, param.OHOW_fastdiv); @@ -1213,14 +1103,10 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, half (&res_)[2] = reinterpret_cast(dst_ptr); if (n < param.n && row < param.k && col < param.PQ) { const uint outOffset = ((ksplit > 0) ? z * param.NKPQ : 0) + n * param.KPQ + row * param.PQ + col; - // if(row == 8 && col == 18) - // printf("A %u, %u, %f \n", outOffset, z, ggml_cuda_cast(res_[0])); output[outOffset] = ggml_cuda_cast(res_[0]); } if (n < param.n && row+1 < param.k && col < param.PQ) { const uint outOffset = ((ksplit > 0) ? z * param.NKPQ : 0) + n * param.KPQ + (row+1) * param.PQ + col; - // if(row+1 == 8 && col == 17) - // printf("B %u, %u, %f \n", outOffset, z, ggml_cuda_cast(res_[0])); output[outOffset] = ggml_cuda_cast(res_[1]); } } @@ -1532,13 +1418,7 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * const uint PD_Y = p[3]; // padding_y const uint DL_X = p[4]; // dilation_x const uint DL_Y = p[5]; // dilation_y - // const int LT = p[6]; // layout - // GGML_ASSERT(LT == 0 || LT == 1); - - // same number of input channels - // GGML_ASSERT(LT == 0 ? input->ne[0] == kernel->ne[0] : input->ne[2] == kernel->ne[2]); - // No cwhn GGML_ASSERT(p[6] == false); const uint IW = input->ne[0]; // input_w @@ -1554,13 +1434,6 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * int64_t pp[3] = {0}; - // const unsigned int K = param.c; -// const uint inChannelOffset = param.c * param.w; -// const uint weightKOffset = param.c * param.r * param.s; -// const unsigned int PQ = param.Ow * param.Oh; -// const unsigned int KPQ = param.k * PQ; -// const unsigned int NKPQ = param.n * KPQ; - param_t params = { B, IC, IH, IW, OC, KH, KW, ST_Y, ST_X, PD_Y, PD_X, DL_Y, DL_X, OH, OW, init_fastdiv_values(KW*IC), diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index ee56c80b7f..2cf03f268f 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -1,6 +1,11 @@ #pragma once #include "common.cuh" +constexpr unsigned int SWIZZLE_MASK_1 = 0b10000; +constexpr unsigned int SWIZZLE_BITS_1 = 4; +constexpr unsigned int SWIZZLE_MASK_2 = 0b1100; +constexpr unsigned int SWIZZLE_BITS_2 = 2; + typedef struct{ unsigned int n; //batch size unsigned int c; //number if channels @@ -24,7 +29,6 @@ typedef struct{ uint3 S_fastdiv; uint3 OHOW_fastdiv; int64_t inc_next[3]; - // unsigned int K; unsigned int inChannelOffset; unsigned int weightKOffset; unsigned int PQ; @@ -37,7 +41,6 @@ typedef struct{ /// Clears the predicates template -// __host__ __device__ void clear_mask(unsigned int masks_[][2], bool clear = true) { __device__ void clear_mask(unsigned int masks_[][2], bool clear = true) { #pragma unroll @@ -48,8 +51,7 @@ __device__ void clear_mask(unsigned int masks_[][2], bool clear = true) { } template -// __host__ __device__ void add_byte_offset(int64_t element_offset[], const int64_t offset){ -__device__ void add_byte_offset(int64_t element_offset[], const int64_t offset){ +__device__ void add_byte_offset(int64_t element_offset[], const int64_t offset) { #pragma unroll for (int s = 0; s < K_STRID; ++s) { element_offset[s] += offset; @@ -63,21 +65,14 @@ template(ptr); - - // int offset_npq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; const unsigned 
int gemm_i = blockIdx.y * TILE_ROWS + thread_row; offset_n[s] = fastdiv(gemm_i, param.OHOW_fastdiv); unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv); @@ -86,19 +81,8 @@ __device__ void prepareIteratorA(unsigned int thread_row, const int h = offset_p[s] * (int)param.u - (int) param.p; const int w = offset_q[s] * (int)param.v - (int) param.q; - // if(threadIdx.x < 32 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0) - // printf("%d, %d : %d, %d, %d, %d offset (%d, %d, %d), kele %llu Kcont %d\n ", thread_idx, s, - // // printf("[%s - %d] %d, %d : %d, %d, %d, %d\n ", __FUNCTION__, __LINE__, thread_idx, s, - // threadblock_offset.row(), thread_coord.strided(), ThreadMap::Delta::kStrided, - // offset_npq, offset_n[s], offset_p[s], offset_q[s], AccessType::kElements, - // ThreadMap::Iterations::kContiguous); - element_offset[s] = offset_n[s] * (int64_t)param.CHW + h * (int64_t)(param.inChannelOffset) + w * (int64_t)param.c; - // if(element_offset[s] >= 327680) - // printf("(%d, %d, %d, %d, %d), %d, %lld, %d, %d, %d, %d, %d, %u, %u, %u \n", - // threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y, blockIdx.z, - // s, element_offset[s], offset_n[s], offset_p[s], offset_q[s], h, w, chw, param.c * param.w, param.c); thread_row += ROW_STEP; } @@ -126,8 +110,7 @@ __device__ void prepareIteratorA(unsigned int thread_row, template __device__ void cp_async_zfill(void *ptr, void const *global_ptr, bool pred_guard = true) { -#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE - +#ifdef CP_ASYNC_AVAILABLE unsigned int smem_ptr; int src_in_bytes = pred_guard ? preload : 0; @@ -154,19 +137,16 @@ __device__ __forceinline__ void tileMemcpySwizzleB( half* __restrict__ dst, const unsigned int curR, const unsigned int curS, + const unsigned int curC, + const int64_t ki, const unsigned int start_k, const unsigned int end_k, unsigned int thread_row, const unsigned int thread_col, - // const unsigned int src_stride, param_t param -){ +) { #if __CUDA_ARCH__ >= GGML_CUDA_TURING - constexpr unsigned int SWIZZLE_MASK_1 = 0b10000; - constexpr unsigned int SWIZZLE_BITS_1 = 4; - constexpr unsigned int SWIZZLE_MASK_2 = 0b1100; - constexpr unsigned int SWIZZLE_BITS_2 = 2; constexpr unsigned int TILE_COLS = 32; float4* dst_float4 = reinterpret_cast(dst); @@ -174,39 +154,27 @@ __device__ __forceinline__ void tileMemcpySwizzleB( // # of threads is multiple of # of columns in the tile constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); - // flatten out 2d grid of threads into in order of increasing threadIdx.x - // const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; // assign each thread a row/column in the tile, calculate how many iterations we need // to cover the whole tile constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; - // unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; - // const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; - - // const unsigned int ki = (curR*param.s+curS)*param.c + start_k+thread_col*8; - // const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset - // const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - // const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // - const unsigned int curC = start_k+thread_col*8; - const unsigned int ki = (curR*param.s+curS)*param.c + curC; #pragma unroll - for 
(unsigned int i = 0; i < NUM_ITERS; i++){ + for (unsigned int i = 0; i < NUM_ITERS; i++) { // apply swizzle to the dst index const unsigned int src_index = thread_row * param.weightKOffset + ki; unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); -#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE - +#ifdef CP_ASYNC_AVAILABLE cp_async_zfill((void *)(&dst_float4[dst_index]), (void const *)(&src[src_index]), thread_row + blockIdx.x * TILE_ROWS < param.k && curC < end_k); #else - if (thread_row + blockIdx.x * TILE_ROWS < param.k && curC < end_k){ + if (thread_row + blockIdx.x * TILE_ROWS < param.k && curC < end_k) { dst_float4[dst_index] = reinterpret_cast(&src[src_index])[0]; - }else{ // read 4 halves + } else { // read 4 halves dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f); } #endif @@ -217,6 +185,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB( GGML_UNUSED(dst); GGML_UNUSED(curR); GGML_UNUSED(curS); + GGML_UNUSED(ki); GGML_UNUSED(start_k); GGML_UNUSED(end_k); GGML_UNUSED(thread_row); @@ -242,14 +211,9 @@ __device__ __forceinline__ unsigned int tileMemcpySwizzleA( const unsigned int start_k, const unsigned int end_k, param_t param -) -{ +) { #if __CUDA_ARCH__ >= GGML_CUDA_TURING - constexpr unsigned int SWIZZLE_MASK_1 = 0b10000; - constexpr unsigned int SWIZZLE_BITS_1 = 4; - constexpr unsigned int SWIZZLE_MASK_2 = 0b1100; - constexpr unsigned int SWIZZLE_BITS_2 = 2; constexpr unsigned int TILE_COLS = 32; float4* dst_float4 = reinterpret_cast(dst); @@ -257,42 +221,26 @@ __device__ __forceinline__ unsigned int tileMemcpySwizzleA( // # of threads is multiple of # of columns in the tile constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); - // flatten out 2d grid of threads into in order of increasing threadIdx.x - // const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; // assign each thread a row/column in the tile, calculate how many iterations we need // to cover the whole tile constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; - // unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; - // const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; - // const unsigned int ki = start_k+thread_col*8; - // const unsigned int chw = param.c * param.h * param.w; - // const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset - // const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - // const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset const unsigned int curC = start_k+thread_col*8; clear_mask(masks, curC >= end_k); #pragma unroll - for (unsigned int i = 0; i < NUM_ITERS; i++){ + for (unsigned int i = 0; i < NUM_ITERS; i++) { bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS)); // apply swizzle to the dst index unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); - // if(threadIdx.x == 3 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 1){ - // printf(" %u, %u, %u, %u, %lld, %d\n", i, curR, curS, curC, 
element_offset[i], valid?1:0); - // } - // if (valid && curC < end_k){ -#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#ifdef CP_ASYNC_AVAILABLE cp_async_zfill((void *)(&dst_float4[dst_index]), (void const *)(&src[element_offset[i]+curC]), valid); #else - if (valid){ - // if(element_offset[i] >= 327680 || element_offset[i] < 0) - // printf("%d, %d, %d, %d, %d, %d, %d, %d, %d \n", threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y, - // i, element_offset[i], curR, curS, curC); + if (valid) { dst_float4[dst_index] = reinterpret_cast(&src[element_offset[i]+curC])[0]; } else { dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f); @@ -300,29 +248,6 @@ __device__ __forceinline__ unsigned int tileMemcpySwizzleA( #endif thread_row += ROW_STEP; } - // #pragma unroll - // for (unsigned int i = 0; i < NUM_ITERS; i++){ - // unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row; - // unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv); - // unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv); - // int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p; - // int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q; - // // unsigned int inOffset = n * param.c * param.h * param.w; - // int curH = posh_ori + curR * param.d_h; // input h - // int curW = posw_ori + curS * param.d_w; // input w - // // apply swizzle to the dst index - // unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; - // dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); - // dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); - // if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && - // curR < param.r && curS < param.s && curC < param.c && n < param.n && ki < end_k){ - // const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; - // dst_float4[dst_index] = reinterpret_cast(&src[n * chw + inOffsetTmp])[0]; - // } else{ - // dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f); - // } - // thread_row += ROW_STEP; - // } return curC; #else GGML_UNUSED(src); @@ -357,42 +282,29 @@ __device__ __forceinline__ unsigned int tileMemcpyLoadA( const unsigned int start_k, const unsigned int end_k, unsigned int oldC, - // const unsigned int inChannelOffset, param_t param -){ +) { #if __CUDA_ARCH__ >= GGML_CUDA_TURING // # of threads is multiple of # of columns in the tile constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); - // flatten out 2d grid of threads into in order of increasing threadIdx.x // assign each thread a row/column in the tile, calculate how many iterations we need // to cover the whole tile constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; - // unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; - // const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; // compile time check that we provided the right amount of registers for storage static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); - // const unsigned int ki = start_k+block_k+thread_col*8; - // const unsigned int chw = param.c * param.h * param.w; - - // const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset - // const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - // const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset const unsigned int curC = 
start_k+block_k+thread_col*8; if (curC > oldC) clear_mask(masks, curC >= end_k); #pragma unroll - for (unsigned int i = 0; i < NUM_ITERS; i++){ + for (unsigned int i = 0; i < NUM_ITERS; i++) { bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS)); - // if(threadIdx.x == 3 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 1){ - // printf(" %u, %u, %u, %u, %u, %lld, %d\n", i, curR, curS, oldC, curC, element_offset[i], valid?1:0); - // } if (valid) { dst_reg[i] = reinterpret_cast(&src[element_offset[i]+curC])[0]; } else{ @@ -435,50 +347,32 @@ __device__ __forceinline__ unsigned int tileMemcpyAsyncLoadA( const unsigned int start_k, const unsigned int end_k, unsigned int oldC, - // const unsigned int inChannelOffset, param_t param -){ -#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE - // # of threads is multiple of # of columns in the tile - constexpr unsigned int SWIZZLE_MASK_1 = 0b10000; - constexpr unsigned int SWIZZLE_BITS_1 = 4; - constexpr unsigned int SWIZZLE_MASK_2 = 0b1100; - constexpr unsigned int SWIZZLE_BITS_2 = 2; +) { +#ifdef CP_ASYNC_AVAILABLE constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); float4* dst_float4 = reinterpret_cast(dst); - // flatten out 2d grid of threads into in order of increasing threadIdx.x // assign each thread a row/column in the tile, calculate how many iterations we need // to cover the whole tile constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; constexpr unsigned int ITER_STEPS = ROW_STEP * TILE_COLS_VECTORIZED; - // unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; - // const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; // compile time check that we provided the right amount of registers for storage static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); - // const unsigned int ki = start_k+block_k+thread_col*8; - // const unsigned int chw = param.c * param.h * param.w; - - // const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset - // const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - // const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset const unsigned int curC = start_k+block_k+thread_col*8; if (curC > oldC) clear_mask(masks, curC >= end_k); unsigned int iter_idx = thread_row * TILE_COLS_VECTORIZED + thread_col; #pragma unroll - for (unsigned int i = 0; i < NUM_ITERS; i++){ - bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS)); - // if(threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 1){ - // printf(" %u, %u, %u, %u, %u, %lld, %d\n", i, curR, curS, oldC, curC, element_offset[i], valid?1:0); - // } + for (unsigned int i = 0; i < NUM_ITERS; i++) { + bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS)); unsigned int dst_index = iter_idx; dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); @@ -515,57 +409,40 @@ __device__ __forceinline__ void tileMemcpyLoadB( float4 (&dst_reg)[ELEMENTS_PER_THREAD], const unsigned int curR, const unsigned int curS, + const unsigned int curC, + const int64_t ki, const unsigned int block_k, const unsigned int start_k, const unsigned int end_k, unsigned int thread_row, const unsigned int thread_col, - // const unsigned 
int src_stride, param_t param -){ +) { #if __CUDA_ARCH__ >= GGML_CUDA_TURING - - constexpr unsigned int SWIZZLE_MASK_1 = 0b10000; - constexpr unsigned int SWIZZLE_BITS_1 = 4; - constexpr unsigned int SWIZZLE_MASK_2 = 0b1100; - constexpr unsigned int SWIZZLE_BITS_2 = 2; - // # of threads is multiple of # of columns in the tile constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); - // flatten out 2d grid of threads into in order of increasing threadIdx.x - // const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; // assign each thread a row/column in the tile, calculate how many iterations we need // to cover the whole tile constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; - // unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; - // const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; // compile time check that we provided the right amount of registers for storage static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); - // const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset - // const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - // const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // - const unsigned int curC = start_k+block_k+thread_col*8; - const unsigned int ki = (curR*param.s+curS)*param.c + curC; - unsigned int iter_idx = thread_row * param.weightKOffset + ki; unsigned int krow_idx = thread_row + blockIdx.x * TILE_ROWS; const int ITER_STEPS = ROW_STEP * param.weightKOffset; #pragma unroll - for (unsigned int i = 0; i < NUM_ITERS; i++){ - // const unsigned int src_index = thread_row * param.weightKOffset + ki; + for (unsigned int i = 0; i < NUM_ITERS; i++) { const unsigned int src_index = iter_idx; - // if (thread_row + blockIdx.x * TILE_ROWS < param.k && curC < end_k){ - if (krow_idx < param.k && curC < end_k){ + if (krow_idx < param.k && curC < end_k) { dst_reg[i] = reinterpret_cast(&src[src_index])[0]; - }else{ // read 4 halves + } else { // read 4 halves dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f); } krow_idx += ROW_STEP; @@ -577,6 +454,7 @@ __device__ __forceinline__ void tileMemcpyLoadB( GGML_UNUSED(block_k); GGML_UNUSED(curR); GGML_UNUSED(curS); + GGML_UNUSED(ki); GGML_UNUSED(start_k); GGML_UNUSED(end_k); GGML_UNUSED(thread_row); @@ -595,27 +473,22 @@ __device__ __forceinline__ void tileMemcpyAsyncLoadB( half *dst, const unsigned int curR, const unsigned int curS, + const unsigned int curC, + const int64_t ki, const unsigned int block_k, const unsigned int start_k, const unsigned int end_k, unsigned int thread_row, const unsigned int thread_col, param_t param -){ +) { -#if __CUDA_ARCH__ >= GGML_CUDA_AMPERE - - constexpr unsigned int SWIZZLE_MASK_1 = 0b10000; - constexpr unsigned int SWIZZLE_BITS_1 = 4; - constexpr unsigned int SWIZZLE_MASK_2 = 0b1100; - constexpr unsigned int SWIZZLE_BITS_2 = 2; +#ifdef CP_ASYNC_AVAILABLE // # of threads is multiple of # of columns in the tile constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); - // flatten out 2d grid of threads into in order of increasing threadIdx.x - // const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; float4* dst_float4 = reinterpret_cast(dst); // assign each thread a row/column in the tile, calculate how many iterations we need @@ -627,17 +500,13 @@ 
__device__ __forceinline__ void tileMemcpyAsyncLoadB( // compile time check that we provided the right amount of registers for storage static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); - const unsigned int curC = start_k+block_k+thread_col*8; - const unsigned int ki = (curR*param.s+curS)*param.c + curC; - unsigned int iter_src_idx = thread_row * param.weightKOffset + ki; unsigned int iter_dst_idx = thread_row * TILE_COLS_VECTORIZED + thread_col; unsigned int krow_idx = thread_row + blockIdx.x * TILE_ROWS; const int ITER_SRC_STEPS = ROW_STEP * param.weightKOffset; #pragma unroll - for (unsigned int i = 0; i < NUM_ITERS; i++){ - // const unsigned int src_index = thread_row * param.weightKOffset + ki; + for (unsigned int i = 0; i < NUM_ITERS; i++) { const unsigned int src_index = iter_src_idx; unsigned int dst_index = iter_dst_idx; dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); @@ -655,6 +524,7 @@ __device__ __forceinline__ void tileMemcpyAsyncLoadB( GGML_UNUSED(block_k); GGML_UNUSED(curR); GGML_UNUSED(curS); + GGML_UNUSED(ki); GGML_UNUSED(start_k); GGML_UNUSED(end_k); GGML_UNUSED(thread_row); @@ -676,14 +546,10 @@ __device__ __forceinline__ void tileMemcpySwizzleStore( half* __restrict__ dst, unsigned int thread_row, const unsigned int thread_col -) -{ +) { #if __CUDA_ARCH__ >= GGML_CUDA_TURING - constexpr unsigned int SWIZZLE_MASK_1 = 0b10000; - constexpr unsigned int SWIZZLE_BITS_1 = 4; - constexpr unsigned int SWIZZLE_MASK_2 = 0b1100; - constexpr unsigned int SWIZZLE_BITS_2 = 2; + constexpr unsigned int TILE_COLS = 32; // reinterpret input/output as float4 @@ -693,26 +559,19 @@ __device__ __forceinline__ void tileMemcpySwizzleStore( constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); - // flatten out 2d grid of threads into in order of increasing threadIdx.x - // const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; - // assign each thread a row/column in the tile, calculate how many iterations we need // to cover the whole tile constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; constexpr unsigned int ITER_STEPS = ROW_STEP * TILE_COLS_VECTORIZED; - // unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; - // const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; // compile time check that we provided the right amount of registers for storage static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); unsigned int iter_idx = thread_row * TILE_COLS_VECTORIZED + thread_col; #pragma unroll - for (unsigned int i = 0; i < NUM_ITERS; i++) - { + for (unsigned int i = 0; i < NUM_ITERS; i++) { // apply swizzle to the dst index - // unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; unsigned int dst_index = iter_idx; dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); From 3e691046dc9c6d270c924711f158f7110f3c5f7d Mon Sep 17 00:00:00 2001 From: bssrdf Date: Mon, 17 Nov 2025 10:41:30 -0500 Subject: [PATCH 108/122] minor tweak filter transpose --- ggml/src/ggml-cuda/conv2d-implicit.cu | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 1ffce7a9d7..8b9c83ba1a 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -14,7 +14,7 @@ constexpr uint WARPSIZE =
32; #define CUDA_NCHW_2_NHWC_TILE_DIM 32 #define CUDA_NCHW_2_NHWC_BLOCK_NM 8 #define CUDA_NCHW_2_NHWC_BLOCK_ROWS 8 -#define CUDA_NCHW_2_NHWC_BLOCK_C 64 +#define CUDA_NCHW_2_NHWC_BLOCK_C 32 //currently not use; in future for split-k kernels @@ -58,12 +58,13 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co int ty = blockIdx.x * CUDA_NCHW_2_NHWC_TILE_DIM + threadIdx.y; __shared__ src_T tile[CUDA_NCHW_2_NHWC_TILE_DIM][CUDA_NCHW_2_NHWC_TILE_DIM]; - +#pragma unroll for(int i = 0; i < CUDA_NCHW_2_NHWC_BLOCK_NM; ++i){ const unsigned int imat = blockIdx.z * CUDA_NCHW_2_NHWC_BLOCK_NM + i; if(imat >= nmat) break; +#pragma unroll for (int j = 0; j < CUDA_NCHW_2_NHWC_TILE_DIM; j += CUDA_NCHW_2_NHWC_BLOCK_ROWS){ if(x < ne01 && y + j < ne00){ const int row = threadIdx.y+j; @@ -72,7 +73,7 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co } } __syncthreads(); - +#pragma unroll for (int j = 0; j < CUDA_NCHW_2_NHWC_TILE_DIM; j += CUDA_NCHW_2_NHWC_BLOCK_ROWS){ if(ty + j < ne01 && tx < ne00){ const int col = (threadIdx.y+j) ^ threadIdx.x; From 9bb5eb30e5a1b8b33d623aa7cd5534c0f52e9dde Mon Sep 17 00:00:00 2001 From: bssrdf Date: Mon, 17 Nov 2025 11:45:01 -0500 Subject: [PATCH 109/122] tuned block dimensions for filter transpose --- ggml/src/ggml-cuda/conv2d-implicit.cu | 46 +++++++++++---------------- 1 file changed, 19 insertions(+), 27 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 8b9c83ba1a..be993d99c1 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -14,7 +14,7 @@ constexpr uint WARPSIZE = 32; #define CUDA_NCHW_2_NHWC_TILE_DIM 32 #define CUDA_NCHW_2_NHWC_BLOCK_NM 8 #define CUDA_NCHW_2_NHWC_BLOCK_ROWS 8 -#define CUDA_NCHW_2_NHWC_BLOCK_C 32 +#define CUDA_NCHW_2_NHWC_BLOCK_C 64 //currently not use; in future for split-k kernels @@ -86,7 +86,6 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co } template static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, const int ne00, const int ne01){ - const int64_t nmat = ne / (ne00 * ne01); const int64_t n = ne00 * ne01; const unsigned int tx = threadIdx.x; @@ -97,32 +96,26 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co __shared__ src_T tile[rs*blk_c]; -#pragma unroll - for(int i = 0; i < CUDA_NCHW_2_NHWC_BLOCK_NM; ++i){ - const unsigned int imat = by * CUDA_NCHW_2_NHWC_BLOCK_NM + i; - if(imat >= nmat) - break; #pragma unroll - for (unsigned int j = 0; j < rs; j++){ - const unsigned int row = (j * blk + tx) % rs; - const unsigned int col = (j * blk + tx) / rs; - const unsigned int src_index = imat*n + bx * blk_c * rs + j * blk + tx; - unsigned int idx = row * blk_c + col; - idx = idx ^ ((idx & mask) >> 4); - if (src_index < ne && tx < blk) { - tile[idx] = src[src_index]; - } + for (unsigned int j = 0; j < rs; j++){ + const unsigned int row = (j * blk + tx) % rs; + const unsigned int col = (j * blk + tx) / rs; + const unsigned int src_index = by*n + bx * blk_c * rs + j * blk + tx; + unsigned int idx = row * blk_c + col; + idx = idx ^ ((idx & mask) >> 4); + if (src_index < ne && tx < blk) { + tile[idx] = src[src_index]; } - __syncthreads(); + } + __syncthreads(); #pragma unroll - for (unsigned int j = 0; j < rs; j++){ - const unsigned int dst_index = imat*n + j*ne00 + bx*blk_c + tx; - if(dst_index < ne && tx < blk){ - unsigned int idx = j*blk_c + tx; - idx = idx ^ ((idx & mask) >> 4); - dst[dst_index] =
ggml_cuda_cast(tile[idx]); - } + for (unsigned int j = 0; j < rs; j++){ + const unsigned int dst_index = by*n + j*ne00 + bx*blk_c + tx; + if(dst_index < ne && tx < blk){ + unsigned int idx = j*blk_c + tx; + idx = idx ^ ((idx & mask) >> 4); + dst[dst_index] = ggml_cuda_cast(tile[idx]); } } } @@ -1222,14 +1215,13 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa ne = P.c * P.r * P.s * P.k; ne01 = P.r * P.s; - // ggml_cuda_pool_alloc kernel_f16(ctx.pool(id), ne); ggml_cuda_pool_alloc kernel_f16(ctx.pool(id)); if (ne01 > 1){ kernel_f16.alloc(ne); dim3 dimGrid1((ne00 + CUDA_NCHW_2_NHWC_BLOCK_C - 1) / CUDA_NCHW_2_NHWC_BLOCK_C, - (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM, - 1) ; + ne/(ne00*ne01), + 1) ; if (ne01 == 25) { constexpr unsigned int mask = filter_swizzle_mask(25, CUDA_NCHW_2_NHWC_BLOCK_C); NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); From 5fbdefdb9da4f56441f0bde50eb20bd1339a79d6 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Mon, 17 Nov 2025 12:01:59 -0500 Subject: [PATCH 110/122] use fastdiv in filter transpose --- ggml/src/ggml-cuda/conv2d-implicit.cu | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index be993d99c1..8361c422ad 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -84,7 +84,7 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co } template -static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, const int ne00, const int ne01){ +static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, const int ne00, const int ne01, param_t P){ const int64_t n = ne00 * ne01; @@ -99,8 +99,9 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co #pragma unroll for (unsigned int j = 0; j < rs; j++){ - const unsigned int row = (j * blk + tx) % rs; - const unsigned int col = (j * blk + tx) / rs; + const int i = j * blk + tx; + const unsigned int row = fastmodulo(i, P.RS_fastdiv); + const unsigned int col = fastdiv(i, P.RS_fastdiv); const unsigned int src_index = by*n + bx * blk_c * rs + j * blk + tx; unsigned int idx = row * blk_c + col; idx = idx ^ ((idx & mask) >> 4); @@ -1224,34 +1225,34 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa 1) ; if (ne01 == 25) { constexpr unsigned int mask = filter_swizzle_mask(25, CUDA_NCHW_2_NHWC_BLOCK_C); - NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01, P); } else if (ne01 == 16) { constexpr unsigned int mask = filter_swizzle_mask(16, CUDA_NCHW_2_NHWC_BLOCK_C); - NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01, P); } else if (ne01 == 9) { constexpr unsigned int mask = filter_swizzle_mask(9, CUDA_NCHW_2_NHWC_BLOCK_C); - NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01, P); } else if (ne01 == 8) { constexpr unsigned int mask = filter_swizzle_mask(8, CUDA_NCHW_2_NHWC_BLOCK_C); - NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01, P); } else if (ne01 == 7) { constexpr unsigned int mask = filter_swizzle_mask(7, CUDA_NCHW_2_NHWC_BLOCK_C); - NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01, P); 
} else if (ne01 == 6) { constexpr unsigned int mask = filter_swizzle_mask(6, CUDA_NCHW_2_NHWC_BLOCK_C); - NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01, P); } else if (ne01 == 5) { constexpr unsigned int mask = filter_swizzle_mask(5, CUDA_NCHW_2_NHWC_BLOCK_C); - NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01, P); } else if (ne01 == 4) { constexpr unsigned int mask = filter_swizzle_mask(4, CUDA_NCHW_2_NHWC_BLOCK_C); - NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01, P); } else if (ne01 == 3) { constexpr unsigned int mask = filter_swizzle_mask(3, CUDA_NCHW_2_NHWC_BLOCK_C); - NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01, P); } else if (ne01 == 2) { constexpr unsigned int mask = filter_swizzle_mask(2, CUDA_NCHW_2_NHWC_BLOCK_C); - NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01, P); } else { dim3 dimGrid2((ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, From 5e491258f97b566db041a4cb7bcab3d2eac2bafc Mon Sep 17 00:00:00 2001 From: bssrdf Date: Mon, 17 Nov 2025 14:34:24 -0500 Subject: [PATCH 111/122] make CI happy --- ggml/src/ggml-cuda/conv2d-implicit.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 8361c422ad..1a80901409 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -1285,9 +1285,9 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa constexpr unsigned int NumThreads = ThreadsM * ThreadsN; const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half); - const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; + const unsigned int nsm = (unsigned int) (ggml_cuda_info().devices[ggml_cuda_get_device()].nsm); // if (BlocksM * BlocksN < nsm && P.c >= 8 * ksplit && (P.c * P.r * P.s) % (8*ksplit) == 0) { - if (BlocksM * BlocksN < 2*(unsigned int)nsm){ + if (BlocksM * BlocksN < 2*nsm){ int j, max_remaining_waves = -1, candidate = -1; int ks = min(20, nsm / (BlocksM * BlocksN)); if (ks < 2 && (BlocksM * BlocksN) % nsm < nsm*4/5) From ba754ce4f3c1f486779d2918b27f16ab21e420f0 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Mon, 17 Nov 2025 16:51:01 -0500 Subject: [PATCH 112/122] remove trailing blanks --- ggml/src/ggml-cuda/conv2d-implicit.cuh | 1 - 1 file changed, 1 deletion(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index 2cf03f268f..409c050c89 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -549,7 +549,6 @@ __device__ __forceinline__ void tileMemcpySwizzleStore( ) { #if __CUDA_ARCH__ >= GGML_CUDA_TURING - constexpr unsigned int TILE_COLS = 32; // reinterpret input/output as float4 From 73444564e632a28553e6b972b288d2b05e0eee7b Mon Sep 17 00:00:00 2001 From: bssrdf Date: Tue, 18 Nov 2025 18:36:45 -0500 Subject: [PATCH 113/122] further reduce repeated index computations --- ggml/src/ggml-cuda/conv2d-implicit.cu | 11 +++++++++-- ggml/src/ggml-cuda/conv2d-implicit.cuh | 16 +++++++++++----- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu
b/ggml/src/ggml-cuda/conv2d-implicit.cu index 1a80901409..917f3a6b1e 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -871,6 +871,11 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, prepareIteratorA(thread_row, masks_a, element_offset_a, param); + unsigned int iter_src_idx = thread_row * param.weightKOffset; + unsigned int iter_dst_idx = thread_row * TILE_COLS_VECTORIZED + thread_col; + unsigned int krow_idx = thread_row + blockIdx.x * BN; + const int ITER_SRC_STEPS = ROW_STEP * param.weightKOffset; + // prefetch the first block tile of A,B into shared memory @@ -923,11 +928,13 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, if (block_krs != num_block_tiles_krs) { #ifdef CP_ASYNC_AVAILABLE curC = tileMemcpyAsyncLoadA(A_block_gmem, SA2, r, s, - masks_a, element_offset_a, thread_row, thread_col, block_k * BK, + masks_a, element_offset_a, thread_row, thread_col, + iter_dst_idx, block_k * BK, start_k, end_k, curC, param); element_offset_b = (r*param.s+s)*param.c + curC; tileMemcpyAsyncLoadB(B_block_gmem, SB2, r, s, curC, element_offset_b, block_k * BK, - start_k, end_k, thread_row, thread_col, param); + start_k, end_k, thread_row, thread_col, + iter_src_idx, iter_dst_idx, krow_idx, ITER_SRC_STEPS,param); asm volatile("cp.async.commit_group;\n" ::); #else curC = tileMemcpyLoadA(A_block_gmem, A_gmem_cache_reg, r, s, diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index 409c050c89..6df8478a47 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -343,6 +343,7 @@ __device__ __forceinline__ unsigned int tileMemcpyAsyncLoadA( const int64_t element_offset[], unsigned int thread_row, const unsigned int thread_col, + unsigned int iter_idx, const unsigned int block_k, const unsigned int start_k, const unsigned int end_k, @@ -369,7 +370,6 @@ __device__ __forceinline__ unsigned int tileMemcpyAsyncLoadA( if (curC > oldC) clear_mask(masks, curC >= end_k); - unsigned int iter_idx = thread_row * TILE_COLS_VECTORIZED + thread_col; #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++) { bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS)); @@ -393,6 +393,7 @@ __device__ __forceinline__ unsigned int tileMemcpyAsyncLoadA( GGML_UNUSED(element_offset); GGML_UNUSED(thread_row); GGML_UNUSED(thread_col); + GGML_UNUSED(iter_idx); GGML_UNUSED(oldC); GGML_UNUSED(param); NO_DEVICE_CODE; @@ -480,6 +481,10 @@ __device__ __forceinline__ void tileMemcpyAsyncLoadB( const unsigned int end_k, unsigned int thread_row, const unsigned int thread_col, + unsigned int iter_src_idx, + unsigned int iter_dst_idx, + unsigned int krow_idx, + const int ITER_SRC_STEPS, param_t param ) { @@ -500,10 +505,7 @@ __device__ __forceinline__ void tileMemcpyAsyncLoadB( // compile time check that we provided the right amount of registers for storage static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); - unsigned int iter_src_idx = thread_row * param.weightKOffset + ki; - unsigned int iter_dst_idx = thread_row * TILE_COLS_VECTORIZED + thread_col; - unsigned int krow_idx = thread_row + blockIdx.x * TILE_ROWS; - const int ITER_SRC_STEPS = ROW_STEP * param.weightKOffset; + iter_src_idx += ki; #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++) { @@ -529,6 +531,10 @@ __device__ __forceinline__ void tileMemcpyAsyncLoadB( GGML_UNUSED(end_k); GGML_UNUSED(thread_row); GGML_UNUSED(thread_col); + GGML_UNUSED(iter_src_idx); + 
GGML_UNUSED(iter_dst_idx); + GGML_UNUSED(krow_idx); + GGML_UNUSED(ITER_SRC_STEPS); GGML_UNUSED(param); NO_DEVICE_CODE; #endif From e760cd49bd7415f18acc29c1309e35d8756c0a15 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 19 Nov 2025 07:47:11 -0500 Subject: [PATCH 114/122] fix CI --- ggml/src/ggml-cuda/conv2d-implicit.cu | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 917f3a6b1e..602bc37a0e 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -871,10 +871,13 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, prepareIteratorA(thread_row, masks_a, element_offset_a, param); + +#ifdef CP_ASYNC_AVAILABLE unsigned int iter_src_idx = thread_row * param.weightKOffset; unsigned int iter_dst_idx = thread_row * TILE_COLS_VECTORIZED + thread_col; unsigned int krow_idx = thread_row + blockIdx.x * BN; const int ITER_SRC_STEPS = ROW_STEP * param.weightKOffset; +#endif // prefetch the first block tile of A,B into shared memory From 0c571feee191cea725811388044ca2a99c552846 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Tue, 27 Jan 2026 10:02:15 -0500 Subject: [PATCH 115/122] use existing cvta_generic_to_shared --- ggml/src/ggml-cuda/conv2d-implicit.cu | 5 +++-- ggml/src/ggml-cuda/conv2d-implicit.cuh | 12 ------------ ggml/src/ggml-cuda/cp-async.cuh | 2 +- tests/test-backend-ops.cpp | 1 + 4 files changed, 5 insertions(+), 15 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 602bc37a0e..80a406e2c5 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -3,6 +3,7 @@ #include "ggml.h" #include "common.cuh" #include "convert.cuh" +#include "cp-async.cuh" #include "conv2d-implicit.cuh" @@ -365,7 +366,7 @@ __device__ __forceinline__ void ldmatrix_a( unsigned int logical_offset = (threadIdx.x % 32) * smem_stride; unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b10000000) >> 4); swizzled_offset = swizzled_offset ^ ((swizzled_offset & 0b1100000) >> 2); - uint32_t src_addr = cvta_to_shared_u32(src + swizzled_offset); + uint32_t src_addr = ggml_cuda_cvta_generic_to_shared(src + swizzled_offset); constexpr unsigned int smem_stride_ = smem_stride * sizeof(half); // convert stride to bytes // 0 @@ -633,7 +634,7 @@ __device__ __forceinline__ void ldmatrix_b( unsigned int logical_offset = (threadIdx.x % 32) * smem_stride; unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b10000000) >> 4); swizzled_offset = swizzled_offset ^ ((swizzled_offset & 0b1100000) >> 2); - uint32_t src_addr = cvta_to_shared_u32(src + swizzled_offset); + uint32_t src_addr = ggml_cuda_cvta_generic_to_shared(src + swizzled_offset); constexpr unsigned int smem_stride_ = smem_stride * sizeof(half); // convert stride to bytes // 0 diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index 6df8478a47..aeaa158d72 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -592,18 +592,6 @@ __device__ __forceinline__ void tileMemcpySwizzleStore( #endif } -__device__ __forceinline__ uint32_t cvta_to_shared_u32(const void *pointer) { - uint32_t address; - asm("{\n\t" - " .reg .u64 u64addr;\n\t" - " cvta.to.shared.u64 u64addr, %1;\n\t" - " cvt.u32.u64 %0, u64addr;\n\t" - "}" - : "=r"(address) - : "l"(pointer)); - return address; -} - template __device__ __forceinline__ void 
loadFilter(const T * __restrict__ kernel, diff --git a/ggml/src/ggml-cuda/cp-async.cuh b/ggml/src/ggml-cuda/cp-async.cuh index 63d0c482ff..91011234b2 100644 --- a/ggml/src/ggml-cuda/cp-async.cuh +++ b/ggml/src/ggml-cuda/cp-async.cuh @@ -3,7 +3,7 @@ #include "common.cuh" -static __device__ __forceinline__ unsigned int ggml_cuda_cvta_generic_to_shared(void * generic_ptr) { +static __device__ __forceinline__ unsigned int ggml_cuda_cvta_generic_to_shared(const void * generic_ptr) { #ifdef CP_ASYNC_AVAILABLE return __cvta_generic_to_shared(generic_ptr); #else diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 7664e9aec9..c254461215 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -40,6 +40,7 @@ #include #include #include +#include #ifdef __EMSCRIPTEN__ # define N_THREADS 1 From 05d9a9132a5af1eac65fc7ed0e1fdad9c8f3b800 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Tue, 27 Jan 2026 10:12:58 -0500 Subject: [PATCH 116/122] remove m16n8k16 path which is not faster than m16n8k8 --- ggml/src/ggml-cuda/conv2d-implicit.cu | 235 +------------------------- 1 file changed, 8 insertions(+), 227 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 80a406e2c5..e73ec150df 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -9,8 +9,6 @@ typedef unsigned int uint; -#define GGML_CUDA_CC_RUBIN 10000 - constexpr uint WARPSIZE = 32; #define CUDA_NCHW_2_NHWC_TILE_DIM 32 #define CUDA_NCHW_2_NHWC_BLOCK_NM 8 @@ -344,25 +342,14 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, template __device__ __forceinline__ void ldmatrix_a( const half* src, -#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN - half (®)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][8] -#else half (®)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] -#endif ){ #if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING static_assert(mma_tiles_per_warp_m == 8, "mma_tiles_per_warp_m must be 8"); -#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN - static_assert(mma_tiles_per_warp_k == 2, "mma_tiles_per_warp_k must be 2"); -#else static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4"); -#endif -#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN - uint32_t (®_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] = reinterpret_cast(reg); -#else uint32_t (®_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][2] = reinterpret_cast(reg); -#endif + unsigned int logical_offset = (threadIdx.x % 32) * smem_stride; unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b10000000) >> 4); swizzled_offset = swizzled_offset ^ ((swizzled_offset & 0b1100000) >> 2); @@ -404,39 +391,7 @@ __device__ __forceinline__ void ldmatrix_a( src_addr ^= 0b10000; // 1 -#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN - asm volatile ( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - "{%0, %1, %2, %3}, [%4];" - : "=r"(reg_[0][0][2]), "=r"(reg_[0][0][3]), "=r"(reg_[1][0][2]), "=r"(reg_[1][0][3]) - : "r"(src_addr) - ); - // 1 - asm volatile ( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - "{%0, %1, %2, %3}, [%4];" - : "=r"(reg_[2][0][2]), "=r"(reg_[2][0][3]), "=r"(reg_[3][0][2]), "=r"(reg_[3][0][3]) - : "r"(src_addr + 32 * smem_stride_) - ); - - // 1 - asm volatile ( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - "{%0, %1, %2, %3}, [%4];" - : "=r"(reg_[4][0][2]), "=r"(reg_[4][0][3]), "=r"(reg_[5][0][2]), "=r"(reg_[5][0][3]) - : "r"(src_addr + 64 * smem_stride_) - ); - - // 1 - asm volatile ( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - "{%0, 
%1, %2, %3}, [%4];" - : "=r"(reg_[6][0][2]), "=r"(reg_[6][0][3]), "=r"(reg_[7][0][2]), "=r"(reg_[7][0][3]) - : "r"(src_addr + 96 * smem_stride_) - ); - -#else asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" @@ -467,43 +422,10 @@ __device__ __forceinline__ void ldmatrix_a( : "=r"(reg_[6][1][0]), "=r"(reg_[6][1][1]), "=r"(reg_[7][1][0]), "=r"(reg_[7][1][1]) : "r"(src_addr + 96 * smem_stride_) ); -#endif src_addr ^= 0b110000; // 2 -#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN - asm volatile ( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - "{%0, %1, %2, %3}, [%4];" - : "=r"(reg_[0][1][0]), "=r"(reg_[0][1][1]), "=r"(reg_[1][1][0]), "=r"(reg_[1][1][1]) - : "r"(src_addr) - ); - - // 2 - asm volatile ( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - "{%0, %1, %2, %3}, [%4];" - : "=r"(reg_[2][1][0]), "=r"(reg_[2][1][1]), "=r"(reg_[3][1][0]), "=r"(reg_[3][1][1]) - : "r"(src_addr + 32 * smem_stride_) - ); - - // 2 - asm volatile ( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - "{%0, %1, %2, %3}, [%4];" - : "=r"(reg_[4][1][0]), "=r"(reg_[4][1][1]), "=r"(reg_[5][1][0]), "=r"(reg_[5][1][1]) - : "r"(src_addr + 64 * smem_stride_) - ); - - // 2 - asm volatile ( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - "{%0, %1, %2, %3}, [%4];" - : "=r"(reg_[6][1][0]), "=r"(reg_[6][1][1]), "=r"(reg_[7][1][0]), "=r"(reg_[7][1][1]) - : "r"(src_addr + 96 * smem_stride_) - ); -#else asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" @@ -534,42 +456,10 @@ __device__ __forceinline__ void ldmatrix_a( : "=r"(reg_[6][2][0]), "=r"(reg_[6][2][1]), "=r"(reg_[7][2][0]), "=r"(reg_[7][2][1]) : "r"(src_addr + 96 * smem_stride_) ); -#endif src_addr ^= 0b10000; // 3 -#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN - asm volatile ( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - "{%0, %1, %2, %3}, [%4];" - : "=r"(reg_[0][1][2]), "=r"(reg_[0][1][3]), "=r"(reg_[1][1][2]), "=r"(reg_[1][1][3]) - : "r"(src_addr) - ); - // 3 - asm volatile ( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - "{%0, %1, %2, %3}, [%4];" - : "=r"(reg_[2][1][2]), "=r"(reg_[2][1][3]), "=r"(reg_[3][1][2]), "=r"(reg_[3][1][3]) - : "r"(src_addr + 32 * smem_stride_) - ); - - // 3 - asm volatile ( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - "{%0, %1, %2, %3}, [%4];" - : "=r"(reg_[4][1][2]), "=r"(reg_[4][1][3]), "=r"(reg_[5][1][2]), "=r"(reg_[5][1][3]) - : "r"(src_addr + 64 * smem_stride_) - ); - - // 3 - asm volatile ( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - "{%0, %1, %2, %3}, [%4];" - : "=r"(reg_[6][1][2]), "=r"(reg_[6][1][3]), "=r"(reg_[7][1][2]), "=r"(reg_[7][1][3]) - : "r"(src_addr + 96 * smem_stride_) - ); -#else asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" @@ -600,7 +490,7 @@ __device__ __forceinline__ void ldmatrix_a( : "=r"(reg_[6][3][0]), "=r"(reg_[6][3][1]), "=r"(reg_[7][3][0]), "=r"(reg_[7][3][1]) : "r"(src_addr + 96 * smem_stride_) ); -#endif + #else GGML_UNUSED(src); GGML_UNUSED(reg); @@ -611,26 +501,14 @@ __device__ __forceinline__ void ldmatrix_a( template __device__ __forceinline__ void ldmatrix_b( const half* src, -#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN - half (®)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][4] -#else half (®)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2] -#endif ){ #if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING - -#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN - static_assert(mma_tiles_per_warp_k == 2, "mma_tiles_per_warp_k must be 2"); -#else static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4"); -#endif 
static_assert(mma_tiles_per_warp_n == 8, "mma_tiles_per_warp_n must be 8"); -#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN - uint32_t (®_) [2][8][2] = reinterpret_cast(reg); -#else uint32_t (®_) [4][8] = reinterpret_cast(reg); -#endif + unsigned int logical_offset = (threadIdx.x % 32) * smem_stride; unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b10000000) >> 4); swizzled_offset = swizzled_offset ^ ((swizzled_offset & 0b1100000) >> 2); @@ -638,21 +516,7 @@ __device__ __forceinline__ void ldmatrix_b( constexpr unsigned int smem_stride_ = smem_stride * sizeof(half); // convert stride to bytes // 0 -#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN - asm volatile ( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - "{%0, %1, %2, %3}, [%4];" - : "=r"(reg_[0][0][0]), "=r"(reg_[0][1][0]), "=r"(reg_[0][2][0]), "=r"(reg_[0][3][0]) - : "r"(src_addr) - ); - asm volatile ( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - "{%0, %1, %2, %3}, [%4];" - : "=r"(reg_[0][4][0]), "=r"(reg_[0][5][0]), "=r"(reg_[0][6][0]), "=r"(reg_[0][7][0]) - : "r"(src_addr + 32 * smem_stride_) - ); -#else asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" @@ -667,25 +531,10 @@ __device__ __forceinline__ void ldmatrix_b( : "=r"(reg_[0][4]), "=r"(reg_[0][5]), "=r"(reg_[0][6]), "=r"(reg_[0][7]) : "r"(src_addr + 32 * smem_stride_) ); -#endif src_addr ^= 0b10000; -#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN - asm volatile ( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - "{%0, %1, %2, %3}, [%4];" - : "=r"(reg_[0][0][1]), "=r"(reg_[0][1][1]), "=r"(reg_[0][2][1]), "=r"(reg_[0][3][1]) - : "r"(src_addr) - ); - asm volatile ( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - "{%0, %1, %2, %3}, [%4];" - : "=r"(reg_[0][4][1]), "=r"(reg_[0][5][1]), "=r"(reg_[0][6][1]), "=r"(reg_[0][7][1]) - : "r"(src_addr + 32 * smem_stride_) - ); -#else asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" @@ -699,25 +548,10 @@ __device__ __forceinline__ void ldmatrix_b( : "=r"(reg_[1][4]), "=r"(reg_[1][5]), "=r"(reg_[1][6]), "=r"(reg_[1][7]) : "r"(src_addr + 32 * smem_stride_) ); -#endif src_addr ^= 0b110000; -#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN - asm volatile ( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - "{%0, %1, %2, %3}, [%4];" - : "=r"(reg_[1][0][0]), "=r"(reg_[1][1][0]), "=r"(reg_[1][2][0]), "=r"(reg_[1][3][0]) - : "r"(src_addr) - ); - asm volatile ( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - "{%0, %1, %2, %3}, [%4];" - : "=r"(reg_[1][4][0]), "=r"(reg_[1][5][0]), "=r"(reg_[1][6][0]), "=r"(reg_[1][7][0]) - : "r"(src_addr + 32 * smem_stride_) - ); -#else asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" @@ -731,25 +565,11 @@ __device__ __forceinline__ void ldmatrix_b( : "=r"(reg_[2][4]), "=r"(reg_[2][5]), "=r"(reg_[2][6]), "=r"(reg_[2][7]) : "r"(src_addr + 32 * smem_stride_) ); -#endif + src_addr ^= 0b10000; -#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN - asm volatile ( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - "{%0, %1, %2, %3}, [%4];" - : "=r"(reg_[1][0][1]), "=r"(reg_[1][1][1]), "=r"(reg_[1][2][1]), "=r"(reg_[1][3][1]) - : "r"(src_addr) - ); - asm volatile ( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - "{%0, %1, %2, %3}, [%4];" - : "=r"(reg_[1][4][1]), "=r"(reg_[1][5][1]), "=r"(reg_[1][6][1]), "=r"(reg_[1][7][1]) - : "r"(src_addr + 32 * smem_stride_) - ); -#else asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" @@ -763,7 +583,6 @@ __device__ __forceinline__ void ldmatrix_b( : "=r"(reg_[3][4]), 
"=r"(reg_[3][5]), "=r"(reg_[3][6]), "=r"(reg_[3][7]) : "r"(src_addr + 32 * smem_stride_) ); -#endif #else GGML_UNUSED(src); GGML_UNUSED(reg); @@ -783,11 +602,8 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, constexpr unsigned int MMA_N = 8; // loop bounds, constexpr where possible allows for loop unrolling -#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN - constexpr unsigned int mma_tiles_per_warp_k = 2; -#else + constexpr unsigned int mma_tiles_per_warp_k = 4; -#endif constexpr unsigned int mma_tiles_per_warp_m = WM / MMA_M; constexpr unsigned int mma_tiles_per_warp_n = WN / MMA_N; const unsigned int z = blockIdx.z; @@ -835,23 +651,15 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, // declare register storage // ptx instructions expect uint32_t registers, where each uint32_t is 2 halfs packed together uint32_t acc_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][2]; -#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN - uint32_t A_register[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4]; - uint32_t B_register[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2]; -#else + uint32_t A_register[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2]; uint32_t B_register[mma_tiles_per_warp_k][mma_tiles_per_warp_n]; -#endif // convenience cast to half for register storage half (&acc_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_n][4] = reinterpret_cast(acc_register); -#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN - half (&A_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][8] = reinterpret_cast(A_register); - half (&B_register_) [mma_tiles_per_warp_k][mma_tiles_per_warp_n][4] = reinterpret_cast(B_register); -#else half (&A_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] = reinterpret_cast(A_register); half (&B_register_) [mma_tiles_per_warp_k][mma_tiles_per_warp_n][2] = reinterpret_cast(B_register); -#endif + // accumulators start at 0 for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++){ for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++){ @@ -968,19 +776,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++) { #pragma unroll for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++) { -#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN - asm volatile ( - "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " - "{%0, %1}, " - "{%2, %3, %4, %5}, " - "{%6, %7}, " - "{%8, %9};" - : "=r"(acc_register[mma_m][mma_n][0]), "=r"(acc_register[mma_m][mma_n][1]) - : "r"(A_register[mma_m][mma_k][0]), "r"(A_register[mma_m][mma_k][1]),"r"(A_register[mma_m][mma_k][2]), "r"(A_register[mma_m][mma_k][3]), - "r"(B_register[mma_k][mma_n][0]), "r"(B_register[mma_k][mma_n][1]) - "r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1]) - ); -#else + asm volatile ( "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " "{%0, %1}, " @@ -992,7 +788,6 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, "r"(B_register[mma_k][mma_n]) "r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1]) ); -#endif } } } @@ -1030,19 +825,6 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++) { #pragma unroll for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++) { -#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN - asm volatile ( - "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " - "{%0, %1}, " - "{%2, 
%3, %4, %5}, " - "{%6, %7}, " - "{%8, %9};" - : "=r"(acc_register[mma_m][mma_n][0]), "=r"(acc_register[mma_m][mma_n][1]) - : "r"(A_register[mma_m][mma_k][0]), "r"(A_register[mma_m][mma_k][1]),"r"(A_register[mma_m][mma_k][2]), "r"(A_register[mma_m][mma_k][3]), - "r"(B_register[mma_k][mma_n][0]), "r"(B_register[mma_k][mma_n][1]) - "r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1]) - ); -#else asm volatile ( "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " "{%0, %1}, " @@ -1054,7 +836,6 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, "r"(B_register[mma_k][mma_n]) "r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1]) ); -#endif } } } From 13820ad653d7d1f935061d58c149683f6a49e8cc Mon Sep 17 00:00:00 2001 From: bssrdf Date: Tue, 27 Jan 2026 19:00:45 -0500 Subject: [PATCH 117/122] added a test case that failed in test-backend-ops --- tests/test-conv2d.cpp | 68 +++++++++++++++++++++++++------------------ 1 file changed, 39 insertions(+), 29 deletions(-) diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index 11cf757bc4..f4eedb942f 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -43,7 +43,7 @@ struct ggml_cgraph * build_graph_1(const test_model&); void load_model(test_model & model, int ic, int oc, int iw, int ih, int kw = 3, int kh = 3, bool use_gpu = false ) { // create data int KW = kw, KH = kh, IC = ic, OC = oc; - int IW = iw, IH = ih, N = 1; + int IW = iw, IH = ih, N = 2; // srand(time(NULL)); // printf(" input: IC = %d, OC = %d, IW = %d, IH = %d \n ", IC, OC, IW, IH); @@ -178,19 +178,19 @@ struct ggml_cgraph * build_graph_0(const test_model& model) { struct ggml_cgraph * gf = ggml_new_graph(ctx0); - int s0 = 1; - int s1 = 1; - int p0 = 1; - int p1 = 1; - int d0 = 1; - int d1 = 1; + // int s0 = 1; + // int s1 = 1; + // int p0 = 1; + // int p1 = 1; + // int d0 = 1; + // int d1 = 1; - // int s0 = 3; - // int s1 = 5; - // int p0 = 5; - // int p1 = 5; - // int d0 = 2; - // int d1 = 4; + int s0 = 1; + int s1 = 5; + int p0 = 5; + int p1 = 2; + int d0 = 2; + int d1 = 4; // recalculate for avoid fragmentation struct ggml_tensor* conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); @@ -224,12 +224,12 @@ struct ggml_cgraph * build_graph_1(const test_model& model) { struct ggml_cgraph * gf = ggml_new_graph(ctx0); - int s0 = 1; - int s1 = 1; - int p0 = 1; - int p1 = 1; - int d0 = 1; - int d1 = 1; + // int s0 = 1; + // int s1 = 1; + // int p0 = 1; + // int p1 = 1; + // int d0 = 1; + // int d1 = 1; // int s0 = 3; @@ -239,6 +239,13 @@ struct ggml_cgraph * build_graph_1(const test_model& model) { // int d0 = 2; // int d1 = 4; + int s0 = 1; + int s1 = 5; + int p0 = 5; + int p1 = 2; + int d0 = 2; + int d1 = 4; + // recalculate for avoid fragmentation // struct ggml_tensor* conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); @@ -319,7 +326,8 @@ static std::vector> configs = { // std::make_tuple(960,320,104,152,3,3), // std::make_tuple(1280,1280,26,38,3,3), // std::make_tuple(1920,640,32,32,3,3) - std::make_tuple(1280,1280,16,16,3,3), + // std::make_tuple(1280,1280,16,16,3,3), + std::make_tuple(1,1,1,133,1,1), // std::make_tuple(32,12,141,133,3,3), // std::make_tuple(32,6,141,133,3,3), // std::make_tuple(32,12,141,121,3,3), @@ -695,6 +703,7 @@ int main(void) test_model model; load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c), std::get<3>(c), std::get<4>(c), std::get<5>(c), true); + // std::get<3>(c), std::get<4>(c), std::get<5>(c), false); ggml_gallocr_t 
allocr = NULL; allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend)); @@ -751,15 +760,16 @@ int main(void) // int i = 2048; // for(int i = 0; i < ggml_nelements(wino_res); i++) { // for(int i = 0; i < 26*38; i++) { - // for(int i = 0; i < conv2d_data.size(); i++) { - // float diff = fabs(im2col_data[i] - conv2d_data[i]); - // // if(diff > 0.5) { - // printf("(%7.3f, %7.3f, %.2f, %d) \n", - // im2col_data[i], conv2d_data[i], - // diff, i); - // // break; - // // } - // } + for(int i = 0; i < conv2d_data.size(); i++) { + float diff = fabs(im2col_data[i] - conv2d_data[i]); + // if(diff > 0.5) { + // printf("(%7.3f, %7.3f, %.2f, %d) \n", + printf("(%f, %f, %f, %d) \n", + im2col_data[i], conv2d_data[i], + diff, i); + // break; + // } + } ggml_free(model.ctx); ggml_backend_buffer_free(model.buffer); From 3c6a80ffa91129ced707a875c7a94d1b035daff5 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 28 Jan 2026 11:18:36 -0500 Subject: [PATCH 118/122] turn on run whole graph to pass all test cases --- tests/test-backend-ops.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index c254461215..50cc0d915d 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4909,6 +4909,11 @@ struct test_conv_2d : public test_case { return 5e-4; } + // must run the whole graph for all test cases to pass + bool run_whole_graph() override { + return true; + } + uint64_t op_flops(ggml_tensor * t) override { GGML_UNUSED(t); // Just counting matmul costs: From 2d39f87bc6b1bc33d9b38305ca9d88194904530e Mon Sep 17 00:00:00 2001 From: bssrdf Date: Thu, 29 Jan 2026 12:36:43 -0500 Subject: [PATCH 119/122] remove run_whole_graph(); not needed after the fix --- tests/test-backend-ops.cpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 44a22512a7..58ba6e8b40 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4909,11 +4909,6 @@ struct test_conv_2d : public test_case { return 5e-4; } - // must run the whole graph for all test cases to pass - bool run_whole_graph() override { - return true; - } - uint64_t op_flops(ggml_tensor * t) override { GGML_UNUSED(t); // Just counting matmul costs: From 805e9ac6a89d323c78ead4abda3ddeb295528ac5 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Thu, 29 Jan 2026 14:47:11 -0500 Subject: [PATCH 120/122] remove test-conv2d --- tests/CMakeLists.txt | 1 - tests/test-conv2d.cpp | 784 ------------------------------------------ 2 files changed, 785 deletions(-) delete mode 100644 tests/test-conv2d.cpp diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 570925a689..c9436c5995 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -230,7 +230,6 @@ if (NOT LLAMA_SANITIZE_ADDRESS AND NOT GGML_SCHED_NO_REALLOC) endif() llama_build_and_test(test-gguf.cpp) llama_build_and_test(test-backend-ops.cpp) -llama_build_and_test(test-conv2d.cpp) llama_build_and_test(test-model-load-cancel.cpp LABEL "model") llama_build_and_test(test-autorelease.cpp LABEL "model") diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp deleted file mode 100644 index f4eedb942f..0000000000 --- a/tests/test-conv2d.cpp +++ /dev/null @@ -1,784 +0,0 @@ -#include "ggml.h" -#include "ggml-alloc.h" -#include "ggml-cpu.h" -#include "ggml-backend.h" - -#ifdef GGML_USE_CUDA -#include "ggml-cuda.h" -//#include -#endif - -#ifdef GGML_USE_METAL -#include "ggml-metal.h" -#endif - -#include -#include -#include -#include -#include -#include 
-#include -#include - -static void ggml_log_callback_default(ggml_log_level level, const char * text, void * user_data) { - (void) level; - (void) user_data; - fputs(text, stderr); - fflush(stderr); -} - -struct test_model { - struct ggml_tensor * a; - struct ggml_tensor * b; - ggml_backend_t backend = NULL; - ggml_backend_buffer_t buffer; - struct ggml_context * ctx; -}; - -void load_model(test_model &, int, int, int, int, int, int, bool); -struct ggml_cgraph * build_graph_0(const test_model&); -struct ggml_cgraph * build_graph_1(const test_model&); - -void load_model(test_model & model, int ic, int oc, int iw, int ih, int kw = 3, int kh = 3, bool use_gpu = false ) { - // create data - int KW = kw, KH = kh, IC = ic, OC = oc; - int IW = iw, IH = ih, N = 2; - // srand(time(NULL)); - - // printf(" input: IC = %d, OC = %d, IW = %d, IH = %d \n ", IC, OC, IW, IH); - - // Initialize adata - std::vector adata(KW * KH * IC * OC); - for (int i = 0; i < KW * KH * IC * OC; i++) { - // adata[i] = 2.f; - // adata[i] = (float)(i%KW)-1.f; - // adata[i] = (float)((i+1)%KW+1)/10.0; - // adata[i] = (float)(i%100); - // adata[i] = (rand() % 255) / 255.0; - float r = -1.f + static_cast (rand()) /( static_cast (RAND_MAX/(1.f-(-1.f)))); - adata[i] = r; - } - - // Convert adata to fp16 format - std::vector hadata(KW * KH * IC * OC); - ggml_fp32_to_fp16_row(adata.data(), hadata.data(), KW * KH * IC * OC); - - // Initialize bdata - std::vector bdata(IW * IH * IC * N); - for (int i = 0; i < IW * IH * IC * N; i++) { - // bdata[i] = (float)(i%IW)/10.f; - // bdata[i] = 1.5f; - // bdata[i] = (rand() % 255) / 255.0; - float r = -1.f + static_cast (rand()) /( static_cast (RAND_MAX/(1.f-(-1.f)))); - bdata[i] = r; - } - - size_t buffer_size = 0; - { - // buffer_size += KW * KH * IC * OC * ggml_type_size(GGML_TYPE_F32); // tensor a - buffer_size += KW * KH * IC * OC * ggml_type_size(GGML_TYPE_F16); // tensor a - buffer_size += IW * IH * IC * N * ggml_type_size(GGML_TYPE_F32); // tensor b - buffer_size += 1024; // overhead - } - - // printf("%s: ggml tensor size = %d bytes\n", __func__, (int) sizeof(ggml_tensor)); - // printf("%s: backend buffer size = %0.2f MB\n", __func__, (buffer_size/ 1024.f/ 1024.f)); - - int num_tensors = 2; - struct ggml_init_params params { - /*.mem_size =*/ ggml_tensor_overhead() * num_tensors, - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ true, - }; - - // initialize the backend -#ifdef GGML_USE_CUDA - if (use_gpu) { - // fprintf(stderr, "%s: using CUDA backend\n", __func__); - model.backend = ggml_backend_cuda_init(0); - if (!model.backend) { - fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); - } - } -#else - GGML_UNUSED(use_gpu); -#endif - -#ifdef GGML_USE_METAL - if (use_gpu) { - fprintf(stderr, "%s: using Metal backend\n", __func__); - model.backend = ggml_backend_metal_init(); - if (!model.backend) { - fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__); - } - } -#else - GGML_UNUSED(use_gpu); -#endif - - if(!model.backend) { - // fallback to CPU backend - model.backend = ggml_backend_cpu_init(); - } - - model.buffer = ggml_backend_alloc_buffer(model.backend, buffer_size); - - // create context - model.ctx = ggml_init(params); - - // create tensors - model.a = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F16, KW, KH, IC, OC); - // model.a = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F32, KW, KH, IC, OC); - model.b = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F32, IW, IH, IC, N); - - // create a allocator - struct ggml_tallocr alloc = ggml_tallocr_new(model.buffer); - - // 
alloc memory - ggml_tallocr_alloc(&alloc, model.a); - - // load data to buffer - if(ggml_backend_is_cpu(model.backend)) { - memcpy(model.a->data, hadata.data(), ggml_nbytes(model.a)); - // memcpy(model.a->data, adata.data(), ggml_nbytes(model.a)); - } else { - ggml_backend_tensor_set(model.a, hadata.data(), 0, ggml_nbytes(model.a)); - // ggml_backend_tensor_set(model.a, adata.data(), 0, ggml_nbytes(model.a)); - } - - // alloc memory - ggml_tallocr_alloc(&alloc, model.b); - - if(ggml_backend_is_cpu(model.backend) -#ifdef GGML_USE_METAL - || ggml_backend_is_metal(model.backend) -#endif - ) { - memcpy(model.b->data, bdata.data(), ggml_nbytes(model.b)); - } else { - ggml_backend_tensor_set(model.b, bdata.data(), 0, ggml_nbytes(model.b)); - } -} - -typedef struct ggml_cgraph* (*build_graph_t)(const test_model& model); - -struct ggml_cgraph * build_graph_0(const test_model& model) { - static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(); - static std::vector buf(buf_size); - - struct ggml_init_params params0 = { - /*.mem_size =*/ buf_size, - /*.mem_buffer =*/ buf.data(), - /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph() - }; - - // create a temporally context to build the graph - struct ggml_context * ctx0 = ggml_init(params0); - - struct ggml_cgraph * gf = ggml_new_graph(ctx0); - - // int s0 = 1; - // int s1 = 1; - // int p0 = 1; - // int p1 = 1; - // int d0 = 1; - // int d1 = 1; - - int s0 = 1; - int s1 = 5; - int p0 = 5; - int p1 = 2; - int d0 = 2; - int d1 = 4; - - // recalculate for avoid fragmentation - struct ggml_tensor* conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); - ggml_set_name(conv2d_res, "conv2d_res"); - ggml_build_forward_expand(gf, conv2d_res); - // int64_t *ne = conv2d_res->ne; - // printf("conv2d: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); - - - // struct ggml_tensor* wino_res = ggml_conv_2d_3x3(ctx0, model.a, model.b); - // ggml_set_name(wino_res, "wino_res"); - // ggml_build_forward_expand(gf, wino_res); - // ne = wino_res->ne; - // printf("wino: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); - ggml_free(ctx0); - return gf; -} - -struct ggml_cgraph * build_graph_1(const test_model& model) { - static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(); - static std::vector buf(buf_size); - - struct ggml_init_params params0 = { - /*.mem_size =*/ buf_size, - /*.mem_buffer =*/ buf.data(), - /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph() - }; - - // create a temporally context to build the graph - struct ggml_context * ctx0 = ggml_init(params0); - - struct ggml_cgraph * gf = ggml_new_graph(ctx0); - - // int s0 = 1; - // int s1 = 1; - // int p0 = 1; - // int p1 = 1; - // int d0 = 1; - // int d1 = 1; - - - // int s0 = 3; - // int s1 = 5; - // int p0 = 5; - // int p1 = 5; - // int d0 = 2; - // int d1 = 4; - - int s0 = 1; - int s1 = 5; - int p0 = 5; - int p1 = 2; - int d0 = 2; - int d1 = 4; - - - // recalculate for avoid fragmentation - // struct ggml_tensor* conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); - // ggml_set_name(conv2d_res, "conv2d_res"); - // ggml_build_forward_expand(gf, conv2d_res); - // int64_t *ne = conv2d_res->ne; - // printf("conv2d: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); - - - // struct ggml_tensor* wino_res = ggml_conv_2d_implicitgemm(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); - struct ggml_tensor* 
wino_res = ggml_conv_2d_direct(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); - ggml_set_name(wino_res, "wino_res"); - ggml_build_forward_expand(gf, wino_res); - // ne = wino_res->ne; - // printf("wino: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); - ggml_free(ctx0); - return gf; -} - -std::vector compute_graph(const test_model &, ggml_gallocr_t, - build_graph_t, int, double *); - - -std::vector compute_graph(const test_model & model, ggml_gallocr_t allocr, - build_graph_t build_graph, int iters, double *t) { - struct ggml_cgraph * gf = build_graph(model); - - - // allocate tensors - ggml_gallocr_alloc_graph(allocr, gf); - int n_threads = 1; - - if (ggml_backend_is_cpu(model.backend)) { - ggml_backend_cpu_set_n_threads(model.backend, n_threads); - } - - ggml_backend_graph_compute(model.backend, gf); - - ggml_backend_synchronize(model.backend); - - int64_t start_time = ggml_time_us(); - - for(int iter=0; iter data(ggml_nelements(res)); - ggml_backend_tensor_get(res, data.data(), 0, ggml_nbytes(res)); - - *t = time_us/1000; - return data; - -} - -static std::vector> configs = { - // std::make_tuple(64,64,48,64,3,3), - // std::make_tuple(320,320,104,152,3,3), - // std::make_tuple(640,640,52,76,3,3), - // std::make_tuple(640,640,104,152,3,3), - // std::make_tuple(960,320,104,152,3,3), - // std::make_tuple(1280,1280,26,38,3,3), - // std::make_tuple(1920,640,32,32,3,3) - // std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(1,1,1,133,1,1), - // std::make_tuple(32,12,141,133,3,3), - // std::make_tuple(32,6,141,133,3,3), - // std::make_tuple(32,12,141,121,3,3), - // std::make_tuple(32,9,141,121,3,3), - // std::make_tuple(320,8,16,16,3,3), //working - // std::make_tuple(320,9,16,16,3,3), //working - // std::make_tuple(320,12,16,16,3,3), //working - // std::make_tuple(256,12,16,16,3,3), //working - // std::make_tuple(32,12,16,16,3,3), //not working - // std::make_tuple(16,12,16,16,3,3), //not working - // std::make_tuple(32,12,16,16,3,3), //not working - // std::make_tuple(48,12,16,16,3,3), // not working - // std::make_tuple(96,12,16,16,3,3), //not working - // std::make_tuple(64,12,16,16,3,3), //working - // std::make_tuple(64,12,141,133,3,3), //working - // std::make_tuple(32,12,141,133,3,3), //working - // std::make_tuple(1280,1280,16,16,3,3), - // std::make_tuple(32,8,24,24,3,3), - // std::make_tuple(640,640,64,64,3,3), - // std::make_tuple(320,640,32,32,3,3), - // std::make_tuple(4,320,96,128,3,3), - // std::make_tuple(320,4,96,128,3,3), - // std::make_tuple(4,320,64,96,3,3), - // std::make_tuple(320,4,64,96,3,3), - // std::make_tuple(640,640,96,128,3,3), - // std::make_tuple(1280,1280,26,38,1,1), - // std::make_tuple(256,128,768,1024,3,3), - // std::make_tuple(128,3,768,1024,3,3), - // std::make_tuple(256,128,768,1024,1,1), - // std::make_tuple(512,256,384,512,1,1), - // std::make_tuple(1280,640,52,76,3,3), - // std::make_tuple(1920,1280,26,38,3,3), - // std::make_tuple(2560,1280,26,38,3,3), - // std::make_tuple(320,1280,26,38,3,3), - // std::make_tuple(512,512,104,152,3,3), - // std::make_tuple(512,512,208,304,3,3), - // std::make_tuple(512,256,416,608,3,3), - // std::make_tuple(256,128,832,1216,3,3), - // std::make_tuple(256,256,832,1216,3,3), - // std::make_tuple(32,64,58,58,3,3) - // std::make_tuple(320,256,1024,1920) - }; - -static std::vector> configs_sdxl_512 = { - //512x512 - std::make_tuple(4,320,64,64,3,3), - std::make_tuple(320,320,64,64,3,3), - std::make_tuple(320,320,64,64,3,3), - std::make_tuple(320,320,64,64,3,3), - std::make_tuple(320,320,64,64,3,3), - 
std::make_tuple(320,320,64,64,3,3), - std::make_tuple(320,640,32,32,3,3), - std::make_tuple(640,640,32,32,3,3), - std::make_tuple(320,640,32,32,3,3), - std::make_tuple(640,640,32,32,3,3), - std::make_tuple(640,640,32,32,3,3), - std::make_tuple(640,640,32,32,3,3), - std::make_tuple(640,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(640,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(2560,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(2560,1280,16,16,3,3), - std::make_tuple(2560,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(2560,1280,16,16,3,3), - std::make_tuple(1920,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(1920,1280,16,16,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1920,640,32,32,3,3), - std::make_tuple(640,640,32,32,3,3), - std::make_tuple(1920,640,32,32,3,3), - std::make_tuple(1280,640,32,32,3,3), - std::make_tuple(640,640,32,32,3,3), - std::make_tuple(1280,640,32,32,3,3), - std::make_tuple(960,640,32,32,3,3), - std::make_tuple(640,640,32,32,3,3), - std::make_tuple(960,640,32,32,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(960,320,64,64,3,3), - std::make_tuple(320,320,64,64,3,3), - std::make_tuple(960,320,64,64,3,3), - std::make_tuple(640,320,64,64,3,3), - std::make_tuple(320,320,64,64,3,3), - std::make_tuple(640,320,64,64,3,3), - std::make_tuple(640,320,64,64,3,3), - std::make_tuple(320,320,64,64,3,3), - std::make_tuple(640,320,64,64,3,3), - std::make_tuple(320,4,64,64,3,3), - std::make_tuple(4,320,64,64,3,3), - std::make_tuple(320,320,64,64,3,3), - std::make_tuple(320,320,64,64,3,3), - std::make_tuple(320,320,64,64,3,3), - std::make_tuple(320,320,64,64,3,3), - std::make_tuple(320,320,64,64,3,3), - std::make_tuple(320,640,32,32,3,3), - std::make_tuple(640,640,32,32,3,3), - std::make_tuple(320,640,32,32,3,3), - std::make_tuple(640,640,32,32,3,3), - std::make_tuple(640,640,32,32,3,3), - std::make_tuple(640,640,32,32,3,3), - std::make_tuple(640,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(640,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(2560,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(2560,1280,16,16,3,3), - std::make_tuple(2560,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(2560,1280,16,16,3,3), - std::make_tuple(1920,1280,16,16,3,3), - std::make_tuple(1280,1280,16,16,3,3), - std::make_tuple(1920,1280,16,16,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1920,640,32,32,3,3), - std::make_tuple(640,640,32,32,3,3), - std::make_tuple(1920,640,32,32,3,3), - std::make_tuple(1280,640,32,32,3,3), - std::make_tuple(640,640,32,32,3,3), - std::make_tuple(1280,640,32,32,3,3), - std::make_tuple(960,640,32,32,3,3), - std::make_tuple(640,640,32,32,3,3), - std::make_tuple(960,640,32,32,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(960,320,64,64,3,3), - std::make_tuple(320,320,64,64,3,3), - std::make_tuple(960,320,64,64,3,3), - std::make_tuple(640,320,64,64,3,3), - 
std::make_tuple(320,320,64,64,3,3), - std::make_tuple(640,320,64,64,3,3), - std::make_tuple(640,320,64,64,3,3), - std::make_tuple(320,320,64,64,3,3), - std::make_tuple(640,320,64,64,3,3), - std::make_tuple(320,4,64,64,3,3) - }; - -static std::vector> configs_sdxl_768 = { - //768x768 - std::make_tuple(4,320,96,96,3,3), - std::make_tuple(320,320,96,96,3,3), - std::make_tuple(320,320,96,96,3,3), - std::make_tuple(320,320,96,96,3,3), - std::make_tuple(320,320,96,96,3,3), - std::make_tuple(320,320,96,96,3,3), - std::make_tuple(320,640,48,48,3,3), - std::make_tuple(640,640,48,48,3,3), - std::make_tuple(320,640,48,48,3,3), - std::make_tuple(640,640,48,48,3,3), - std::make_tuple(640,640,48,48,3,3), - std::make_tuple(640,640,48,48,3,3), - std::make_tuple(640,1280,24,24,3,3), - std::make_tuple(1280,1280,24,24,3,3), - std::make_tuple(640,1280,24,24,3,3), - std::make_tuple(1280,1280,24,24,3,3), - std::make_tuple(1280,1280,24,24,3,3), - std::make_tuple(1280,1280,24,24,3,3), - std::make_tuple(1280,1280,24,24,3,3), - std::make_tuple(1280,1280,24,24,3,3), - std::make_tuple(1280,1280,24,24,3,3), - std::make_tuple(2560,1280,24,24,3,3), - std::make_tuple(1280,1280,24,24,3,3), - std::make_tuple(2560,1280,24,24,3,3), - std::make_tuple(2560,1280,24,24,3,3), - std::make_tuple(1280,1280,24,24,3,3), - std::make_tuple(2560,1280,24,24,3,3), - std::make_tuple(1920,1280,24,24,3,3), - std::make_tuple(1280,1280,24,24,3,3), - std::make_tuple(1920,1280,24,24,3,3), - std::make_tuple(1280,1280,48,48,3,3), - std::make_tuple(1920,640,48,48,3,3), - std::make_tuple(640,640,48,48,3,3), - std::make_tuple(1920,640,48,48,3,3), - std::make_tuple(1280,640,48,48,3,3), - std::make_tuple(640,640,48,48,3,3), - std::make_tuple(1280,640,48,48,3,3), - std::make_tuple(960,640,48,48,3,3), - std::make_tuple(640,640,48,48,3,3), - std::make_tuple(960,640,48,48,3,3), - std::make_tuple(640,640,96,96,3,3), - std::make_tuple(960,320,96,96,3,3), - std::make_tuple(320,320,96,96,3,3), - std::make_tuple(960,320,96,96,3,3), - std::make_tuple(640,320,96,96,3,3), - std::make_tuple(320,320,96,96,3,3), - std::make_tuple(640,320,96,96,3,3), - std::make_tuple(640,320,96,96,3,3), - std::make_tuple(320,320,96,96,3,3), - std::make_tuple(640,320,96,96,3,3), - std::make_tuple(320,4,96,96,3,3), - std::make_tuple(4,320,96,96,3,3), - std::make_tuple(320,320,96,96,3,3), - std::make_tuple(320,320,96,96,3,3), - std::make_tuple(320,320,96,96,3,3), - std::make_tuple(320,320,96,96,3,3), - std::make_tuple(320,320,96,96,3,3), - std::make_tuple(320,640,48,48,3,3), - std::make_tuple(640,640,48,48,3,3), - std::make_tuple(320,640,48,48,3,3), - std::make_tuple(640,640,48,48,3,3), - std::make_tuple(640,640,48,48,3,3), - std::make_tuple(640,640,48,48,3,3), - std::make_tuple(640,1280,24,24,3,3), - std::make_tuple(1280,1280,24,24,3,3), - std::make_tuple(640,1280,24,24,3,3), - std::make_tuple(1280,1280,24,24,3,3), - std::make_tuple(1280,1280,24,24,3,3), - std::make_tuple(1280,1280,24,24,3,3), - std::make_tuple(1280,1280,24,24,3,3), - std::make_tuple(1280,1280,24,24,3,3), - std::make_tuple(1280,1280,24,24,3,3), - std::make_tuple(2560,1280,24,24,3,3), - std::make_tuple(1280,1280,24,24,3,3), - std::make_tuple(2560,1280,24,24,3,3), - std::make_tuple(2560,1280,24,24,3,3), - std::make_tuple(1280,1280,24,24,3,3), - std::make_tuple(2560,1280,24,24,3,3), - std::make_tuple(1920,1280,24,24,3,3), - std::make_tuple(1280,1280,24,24,3,3), - std::make_tuple(1920,1280,24,24,3,3), - std::make_tuple(1280,1280,48,48,3,3), - std::make_tuple(1920,640,48,48,3,3), - std::make_tuple(640,640,48,48,3,3), - 
std::make_tuple(1920,640,48,48,3,3), - std::make_tuple(1280,640,48,48,3,3), - std::make_tuple(640,640,48,48,3,3), - std::make_tuple(1280,640,48,48,3,3), - std::make_tuple(960,640,48,48,3,3), - std::make_tuple(640,640,48,48,3,3), - std::make_tuple(960,640,48,48,3,3), - std::make_tuple(640,640,96,96,3,3), - std::make_tuple(960,320,96,96,3,3), - std::make_tuple(320,320,96,96,3,3), - std::make_tuple(960,320,96,96,3,3), - std::make_tuple(640,320,96,96,3,3), - std::make_tuple(320,320,96,96,3,3), - std::make_tuple(640,320,96,96,3,3), - std::make_tuple(640,320,96,96,3,3), - std::make_tuple(320,320,96,96,3,3), - std::make_tuple(640,320,96,96,3,3), - std::make_tuple(320,4,96,96,3,3), - }; - -static std::vector> configs_sdxl_1024 = { - //1024x1024 - std::make_tuple(4,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(320,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(320,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(640,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(640,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(2560,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(2560,1280,32,32,3,3), - std::make_tuple(2560,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(2560,1280,32,32,3,3), - std::make_tuple(1920,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1920,1280,32,32,3,3), - std::make_tuple(1280,1280,64,64,3,3), - std::make_tuple(1920,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(1920,640,64,64,3,3), - std::make_tuple(1280,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(1280,640,64,64,3,3), - std::make_tuple(960,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(960,640,64,64,3,3), - std::make_tuple(640,640,128,128,3,3), - std::make_tuple(960,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(960,320,128,128,3,3), - std::make_tuple(640,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(640,320,128,128,3,3), - std::make_tuple(640,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(640,320,128,128,3,3), - std::make_tuple(320,4,128,128,3,3), - std::make_tuple(4,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(320,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(320,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(640,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(640,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - 
std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(2560,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(2560,1280,32,32,3,3), - std::make_tuple(2560,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(2560,1280,32,32,3,3), - std::make_tuple(1920,1280,32,32,3,3), - std::make_tuple(1280,1280,32,32,3,3), - std::make_tuple(1920,1280,32,32,3,3), - std::make_tuple(1280,1280,64,64,3,3), - std::make_tuple(1920,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(1920,640,64,64,3,3), - std::make_tuple(1280,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(1280,640,64,64,3,3), - std::make_tuple(960,640,64,64,3,3), - std::make_tuple(640,640,64,64,3,3), - std::make_tuple(960,640,64,64,3,3), - std::make_tuple(640,640,128,128,3,3), - std::make_tuple(960,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(960,320,128,128,3,3), - std::make_tuple(640,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(640,320,128,128,3,3), - std::make_tuple(640,320,128,128,3,3), - std::make_tuple(320,320,128,128,3,3), - std::make_tuple(640,320,128,128,3,3), - std::make_tuple(320,4,128,128,3,3) - }; - - -int main(void) -{ - ggml_time_init(); - - double time_iter0 = 0.0, time_iter1 = 0.0; - - int k = 0; - - // for (auto c : configs_sdxl_1024){ - for (auto c : configs){ - test_model model; - load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c), - std::get<3>(c), std::get<4>(c), std::get<5>(c), true); - // std::get<3>(c), std::get<4>(c), std::get<5>(c), false); - - ggml_gallocr_t allocr = NULL; - allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend)); - - //create the worst case graph for memory usage estimation - struct ggml_cgraph * gf = build_graph_0(model); - - // compute the required memory - ggml_gallocr_reserve(allocr, gf); - size_t mem_size0 = ggml_gallocr_get_buffer_size(allocr, 0); - // fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f); - - - int iterations = 0; - - double run_time0; - std::vector im2col_data = compute_graph(model, allocr, build_graph_0, iterations, &run_time0); - - ggml_gallocr_free(allocr); - - allocr = NULL; - - allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend)); - - //create the worst case graph for memory usage estimation - gf = build_graph_1(model); - - // compute the required memory - ggml_gallocr_reserve(allocr, gf); - size_t mem_size1 = ggml_gallocr_get_buffer_size(allocr, 0); - // fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f); - - - - double run_time1; - // std::vector wino_data = compute_graph(model, allocr, build_graph_1, iterations, &run_time1); - std::vector conv2d_data = compute_graph(model, allocr, build_graph_1, iterations, &run_time1); - - if(k==0) { - k = 1; - fprintf(stdout, "| (IC, OC, IW, IH, KW, KH) | im2col+GEMM TIME | im2col+GEMM VRAM | implicit GEMM TIME | implicit GEMM VRAM \n"); - fprintf(stdout, "| --- | --- | --- | --- | --- \n"); - } - - time_iter0 += run_time0; - time_iter1 += run_time1; - - - fprintf(stdout, " | (%d, %d, %d, %d, %d, %d) | %.2f ms | %.2f MB | %.2f ms | %.2f MB\n", - std::get<0>(c), std::get<1>(c), std::get<2>(c), std::get<3>(c), std::get<4>(c), std::get<5>(c), - run_time0, mem_size0/1024.0f/1024.0f, - run_time1, mem_size1/1024.0f/1024.0f); - - // int i = 2048; - // for(int i = 0; i < ggml_nelements(wino_res); i++) { - // for(int i = 0; i < 26*38; i++) { - for(int i 
= 0; i < conv2d_data.size(); i++) { - float diff = fabs(im2col_data[i] - conv2d_data[i]); - // if(diff > 0.5) { - // printf("(%7.3f, %7.3f, %.2f, %d) \n", - printf("(%f, %f, %f, %d) \n", - im2col_data[i], conv2d_data[i], - diff, i); - // break; - // } - } - - ggml_free(model.ctx); - ggml_backend_buffer_free(model.buffer); - ggml_backend_free(model.backend); - ggml_gallocr_free(allocr); - - } - printf("| 1 unet iter takes| %.2f ms | | %.2f ms | \n", time_iter0, time_iter1); - - // printf("\nPerforming test:\n"); - return 0; -} From 75e5a9bd0159e86a78b851b6b28a387af648e5de Mon Sep 17 00:00:00 2001 From: bssrdf Date: Thu, 29 Jan 2026 16:08:56 -0500 Subject: [PATCH 121/122] fix to work for Turing --- ggml/src/ggml-cuda/conv2d-implicit.cu | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index e73ec150df..67f6cb33b4 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -344,7 +344,7 @@ __device__ __forceinline__ void ldmatrix_a( const half* src, half (®)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] ){ -#if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING +#ifdef CP_ASYNC_AVAILABLE static_assert(mma_tiles_per_warp_m == 8, "mma_tiles_per_warp_m must be 8"); static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4"); @@ -503,7 +503,7 @@ __device__ __forceinline__ void ldmatrix_b( const half* src, half (®)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2] ){ -#if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING +#ifdef CP_ASYNC_AVAILABLE static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4"); static_assert(mma_tiles_per_warp_n == 8, "mma_tiles_per_warp_n must be 8"); @@ -981,8 +981,7 @@ static void launch_conv2d_implicit_split_kernel(ggml_backend_cuda_context & ctx, static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const float * X_D, const half * K_D, float * Y_D, int cc, param_t P, cudaStream_t st) { - // if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0 && (P.r > 1 || P.s > 1)) { - if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0 && (P.r <= 32 && P.s <= 32)) { + if (GGML_CUDA_CC_IS_NVIDIA(cc) && ampere_mma_available(cc) && P.c % 8 == 0 && (P.r <= 32 && P.s <= 32)) { int id = ggml_cuda_get_device(); From 8744a9f7fa24d7b79f26d7cb339f0fd24c61d031 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Fri, 30 Jan 2026 12:48:34 -0500 Subject: [PATCH 122/122] exclude HIP from using tensor core --- ggml/src/ggml-cuda/conv2d-implicit.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 67f6cb33b4..37144970d3 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -596,7 +596,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, const half * __restrict__ kernel, T * __restrict__ output, const param_t param) { -#if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING +#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING constexpr unsigned int MMA_M = 16; constexpr unsigned int MMA_N = 8;
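The last two patches tighten when this tensor-core kernel is compiled at all: PATCH 121 keys the ldmatrix helpers off CP_ASYNC_AVAILABLE and switches the host-side dispatch to GGML_CUDA_CC_IS_NVIDIA(cc) && ampere_mma_available(cc), and PATCH 122 additionally excludes HIP builds from the __CUDA_ARCH__ guard around the kernel body. A minimal sketch of that guard pattern (not part of the patch), with an assumed GGML_CUDA_CC_TURING value and a placeholder kernel body, looks like this:

// Illustrative sketch only: the compile-time guard pattern from PATCH 121/122.
// GGML_CUDA_CC_TURING is assumed here to be ggml's constant for compute
// capability 7.5; the kernel body is a stand-in, not the real implicit GEMM.
#ifndef GGML_CUDA_CC_TURING
#define GGML_CUDA_CC_TURING 750
#endif

static __global__ void guarded_path_sketch(float * dst) {
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
    // NVIDIA Turing-or-newer device code: this is where the ldmatrix/mma.sync
    // path would live.
    if (threadIdx.x == 0 && blockIdx.x == 0) {
        dst[0] = 1.0f; // stand-in for the real tensor-core computation
    }
#else
    // HIP builds and pre-Turing architectures compile an empty body; the
    // host-side dispatch shown in PATCH 121 is expected to keep such devices
    // from ever launching this kernel.
    (void) dst;
#endif
}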