From 16b0f0ae3c76576bf6f325951f0bb0332ce70f06 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Mon, 13 Oct 2025 18:41:30 -0400 Subject: [PATCH] work in progress --- ggml/src/ggml-cuda/conv2d-implicit.cu | 782 +++++++++++++++++++------ ggml/src/ggml-cuda/conv2d-implicit.cuh | 22 + 2 files changed, 625 insertions(+), 179 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index f2af27a7fb..0a5c370f29 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -1,172 +1,338 @@ #include "conv2d-implicit.cuh" #include "convert.cuh" -typedef struct{ - unsigned int n; //batch size - unsigned int c; //number if channels - unsigned int h; //height - unsigned int w; //width - unsigned int k; //number of filters - unsigned int r; //filter height - unsigned int s; //filter width - unsigned int u; //stride height - unsigned int v; //stride width - unsigned int p; //padding height - unsigned int q; //padding width - unsigned int d_h; //dilation height - unsigned int d_w; //dilation width - unsigned int Oh; //output height - unsigned int Ow; //output width -} param_t; + +static const int WARPSIZE = 32; // warpSize is not constexpr + +static __global__ void reduce_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols, const int nrows) { + const int row = blockIdx.x; + const int col = threadIdx.x; + + float sum = 0.0f; + if (row * blockDim.x + col < ncols) { + for (int i = 0; i < nrows; ++i){ + sum += x[i * ncols + row * blockDim.x + col]; + } + dst[row * blockDim.x + col] = sum; + } +} - -template +template static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, const T * __restrict__ kernel, float * __restrict__ output, const param_t param) { - extern __shared__ unsigned char smem[]; + // __shared__ char smem[4 * (TM*TN*NUM_THREADS <= (BM * BK + BK * (BN+PAD)) ? (BM * BK + BK * (BN+PAD)) : (TM*TN*NUM_THREADS))]; + __shared__ char smem[sizeof(float) * (TM*TN*NUM_THREADS) <= sizeof(float) * 2 * BM * BK + sizeof(T)*2*BK * (BN+PAD) ? + sizeof(float)*2*BM*BK + sizeof(T)*2*BK*(BN+PAD) : sizeof(float) * (TM*TN*NUM_THREADS)]; + // __shared__ float smeminput[2 * BM * BK]; + // __shared__ float smemweight[2 * BK * (BN+PAD)]; T *smemweight = reinterpret_cast(smem); - float *smeminput = reinterpret_cast(smem + 16 * 1024); + float *smeminput = reinterpret_cast(smem + 2 * BK * (BN+PAD) * sizeof(T)); + + const uint tx = threadIdx.x; + const uint bx = blockIdx.x; + const uint by = blockIdx.y; + + const uint PQ = param.Oh * param.Ow; - int tx = threadIdx.x; - int bx = blockIdx.x; - int by = blockIdx.y; - - // Warp tile - const int lane_id = threadIdx.x & 31; - const int warp_id = threadIdx.x >> 5; - const int mma_tid_x = (lane_id >> 1) % 8; - const int mma_tid_y = (lane_id >> 4) * 2 + (lane_id & 1); - // lds addr - int weight_lds_addr = (warp_id >> 1) * 32 + mma_tid_y * 4; - int input_lds_addr = (warp_id & 1) * 64 + mma_tid_x * 4; + const uint lane_id = tx % WARPSIZE; + const uint warp_id = tx / WARPSIZE; + const int mma_tid_x = warp_id / (BN / WN); //(lane_id / 2) % 8; + const int mma_tid_y = warp_id % (BN / WN); //(lane_id / 16) * 2 + (lane_id % 2); - // int x = bx * 128 + input_lds_addr; - // int y = by * 128 + weight_lds_addr; + // lds addr + // int weight_lds_addr = (warp_id / 2) * 32 + mma_tid_y * 4; + // int input_lds_addr = (warp_id % 2) * 64 + mma_tid_x * 4; + + // size of the warp subtile + constexpr uint WMITER = (WM * WN) / (WARPSIZE * TM * TN * WNITER); + constexpr uint WSUBM = WM / WMITER; // 64/2=32 + constexpr uint WSUBN = WN / WNITER; // 32/2=16 + + // Placement of the thread in the warp subtile + // const uint threadIdxInWarp = tx % WARPSIZE; // [0, 31] + const uint threadColInWarp = lane_id % (WSUBN / TN); // i%(16/4) + const uint threadRowInWarp = lane_id / (WSUBN / TN); // i/4 + + // int x = bx * BM + input_lds_addr; + // int y = by * BN + weight_lds_addr; int z = blockIdx.z; - T weight_ldg_reg[4]; - float input_ldg_reg[4]; - - int posh_ori[4]; - int posw_ori[4]; -#pragma unroll - for (int i = 0; i < 4; ++i){ - posh_ori[i] = ((bx * 128 + lane_id + i * 32) / param.Ow) * param.u - param.p; - posw_ori[i] = ((bx * 128 + lane_id + i * 32) % param.Ow) * param.v - param.q; - } - int inOffset = z * param.c * param.h * param.w; - int weiOffset = (by * 128 + (tx >> 3) * 4) * param.c * param.r * param.s; - int inChannelOffset = param.h * param.w; + // float weight_ldg_reg[4]; + // float input_ldg_reg[4]; + // 当前线程处理的数据点在oh、ow上的坐标 + // int posh_ori = ((bx * 128 + tx / 2 ) / param.Ow) * param.u - param.p; + // int posw_ori = ((bx * 128 + tx / 2 ) % param.Ow) * param.v - param.q; + // int posh_ori = fastdiv(bx * BM + tx / 2, param.OW_fastdiv) * param.u - param.p; + // int posw_ori = fastmodulo(bx * BM + tx / 2, param.OW_fastdiv) * param.v - param.q; + + + // int inOffset = (ksplit > 0): z * param.c * param.h * param.w ; + // int weiOffset = (by * BN + tx / 8 * 4) * param.c * param.r * param.s; + int inChannelOffset = param.c * param.w; // int weightChannelOffset = param.r * param.s; int weightKOffset = param.c * param.r * param.s; - // sts addr - int weight_sts_addr = (tx & 7) * 132 + - (tx >> 3) * 4; - int input_sts_addr = (warp_id) * 128 + (lane_id); + // uint ks, start_k; + // if constexpr (ksplit > 0){ + // const uint ks = (weightKOffset + ksplit - 1) / ksplit; + // const uint start_k = z * ks; + // } else { + // const uint ks = weightKOffset; + // const uint start_k = 0; + // } + const uint ks = (ksplit > 0) ? (weightKOffset + ksplit - 1) / ksplit : weightKOffset; + const uint start_k = (ksplit > 0)? z * ks: 0; + const uint end_k = min(start_k + ks, weightKOffset); + + // sts addr + // int weight_sts_addr = (tx % 8) * 132 + + // (tx / 8) * 4; int write_flag = 1; - T weight_frag[2][8]; - float input_frag[2][8]; - float output_frag[8][8]; -#pragma unroll - for (int i = 0; i < 8; ++i){ -#pragma unroll - for (int j = 0; j < 8; ++j){ - output_frag[i][j] = 0; - } - } + T weight_frag[2][WNITER * TN]; + float input_frag[2][WMITER * TM] = {0.f}; + float output_frag[WMITER * TM * WNITER * TN] = {0.f}; +// #pragma unroll +// for (int i = 0; i < 8; ++i) +// { +// #pragma unroll +// for (int j = 0; j < 8; ++j) +// { +// output_frag[i][j] = 0; +// } +// } + + // calculating the indices that this thread will load into SMEM + // we'll load 128bit / 32bit = 4 elements per thread at each step + const uint innerRowA = tx / (BK / 4); + const uint innerColA = tx % (BK / 4); + constexpr uint rowStrideA = (NUM_THREADS * 4) / BK; + // const uint innerRowB = tx / (BN / 4); + // const uint innerColB = tx % (BN / 4); + // constexpr uint rowStrideB = NUM_THREADS / (BN / 4); + // ldg -#pragma unroll - for (int i = 0; i < 4; ++i){ - if (tx % 8 < weightKOffset && by * 128 + (tx >> 3) * 4 + i < param.k){ - weight_ldg_reg[i] = kernel[weiOffset + (tx & 7) + i * weightKOffset]; - } - else{ - weight_ldg_reg[i] = (T)0.f; + const uint weight_sts_addr = innerRowA + innerColA * (BN+PAD) * 4; + for (uint offset = 0; offset + rowStrideA <= BN; offset += rowStrideA) { + if(vec_load){ + // if (by * BN + innerRowA + offset < param.k && start_k + innerColA * 4 < param.c * param.r * param.s){ + if (by * BN + innerRowA + offset < param.k && start_k + innerColA * 4 < end_k){ + if constexpr (std::is_same_v(T, float)){ + float4 tmp = reinterpret_cast(&kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4])[0]; + smemweight[weight_sts_addr + offset + 0] = tmp.x; + smemweight[weight_sts_addr + offset + (BN+PAD)] = tmp.y; + smemweight[weight_sts_addr + offset + 2*(BN+PAD)] = tmp.z; + smemweight[weight_sts_addr + offset + 3*(BN+PAD)] = tmp.w; + }else{ // read 4 halves + // half val[4]; + float2 tmp = reinterpret_cast(&kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4])[0]; + half *val = reinterpret_cast(&tmp); + // val[1] = reinterpret_cast(&tmp.y); + smemweight[weight_sts_addr + offset + 0] = val[0]; + smemweight[weight_sts_addr + offset + (BN+PAD)] = val[1]; + smemweight[weight_sts_addr + offset + 2*(BN+PAD)] = val[2]; + smemweight[weight_sts_addr + offset + 3*(BN+PAD)] = val[3]; + } + } else { + #pragma unroll + for (int i = 0; i < 4; ++i){ + smemweight[weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f; + } + } + }else{ + #pragma unroll + for (int i = 0; i < 4; ++i){ + if (by * BN + innerRowA + offset < param.k && start_k + innerColA * 4 + i < end_k){ + // float4 tmp = reinterpret_cast(¶m.weight[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4])[0]; + smemweight[weight_sts_addr + offset + i*(BN+PAD)] = kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4 + i]; + } else { + smemweight[weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f; + } + } } } - int curC = (warp_id) / (param.r * param.s); // channel offset - int curR = ((warp_id) % (param.r * param.s)) / param.s; // kernel r offset - int curS = ((warp_id) % (param.r * param.s)) % param.s; // kernel s offset -#pragma unroll - for (int i = 0; i < 4; ++i){ - int curH = posh_ori[i] + curR * param.d_h; // input h - int curW = posw_ori[i] + curS * param.d_w; // input w - int inOffsetTmp = curC * inChannelOffset + curH * param.w + curW; - if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curC < param.c){ - input_ldg_reg[i] = input[inOffset + inOffsetTmp]; - } - else{ - input_ldg_reg[i] = 0.0; + + + // int curC = (tx / 32) / (param.r * param.s); // channel offset + // int curR = ((tx / 32) % (param.r * param.s)) / param.s; // kernel r offset + // int curS = ((tx / 32) % (param.r * param.s)) % param.s; // kernel s offset + + // int curR = (tx % 2) * 4 / (param.s * param.c); // channel offset + // int curS = ((tx % 2) * 4 % (param.s * param.c)) / param.c; // kernel r offset + // int curC = ((tx % 2) * 4 % (param.s * param.c)) % param.c; // kernel s offset + + const uint input_sts_addr = innerRowA + innerColA * BM * 4; + for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) { + int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQ : z; + const unsigned int npq_res = (bx * BM + innerRowA + offset) % PQ; + const int posh_ori = fastdiv((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.u - param.p; + const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q; + int inOffset = n * param.c * param.h * param.w ; + if(vec_load){ + const uint curR = fastdiv(start_k + innerColA * 4, param.SC_fastdiv); // channel offset + const uint curS = fastdiv(fastmodulo(start_k + innerColA * 4, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const uint curC = fastmodulo(fastmodulo(start_k + innerColA * 4, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + + const int curH = posh_ori + curR; // input h + const int curW = posw_ori + curS; // input w + if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 < end_k){ + int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; + float4 tmp = reinterpret_cast(&input[inOffset + inOffsetTmp])[0]; + smeminput[input_sts_addr + offset + 0] = tmp.x; + smeminput[input_sts_addr + offset + BM] = tmp.y; + smeminput[input_sts_addr + offset + 2*BM] = tmp.z; + smeminput[input_sts_addr + offset + 3*BM] = tmp.w; + } else { + #pragma unroll + for (int i = 0; i < 4; ++i) + smeminput[input_sts_addr + offset + i*BM] = 0.f; + } + } else { + #pragma unroll + for (int i = 0; i < 4; ++i){ + const uint curR = fastdiv(start_k + innerColA * 4 + i, param.SC_fastdiv); // channel offset + const uint curS = fastdiv(fastmodulo(start_k + innerColA * 4 + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const uint curC = fastmodulo(fastmodulo(start_k + innerColA * 4 + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + + const int curH = posh_ori + curR; // input h + const int curW = posw_ori + curS; // input w + if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 + i < end_k){ + int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; + smeminput[input_sts_addr + offset + i*BM] = input[inOffset + inOffsetTmp]; + } else { + smeminput[input_sts_addr + offset + i*BM] = 0.f; + } + } } } + // sts - for (int i = 0; i < 4; ++i){ - smemweight[weight_sts_addr + i] = weight_ldg_reg[i]; - } - for (int i = 0; i < 4; ++i){ - smeminput[input_sts_addr + i * 32] = input_ldg_reg[i]; - } + // for (int i = 0; i < 4; ++i) + // { + // smemweight[weight_sts_addr + i*132] = weight_ldg_reg[i]; + // } + // for (int i = 0; i < 4; ++i) + // { + // smeminput[input_sts_addr + i * 128] = input_ldg_reg[i]; + // } __syncthreads(); + + // if(tx == 0 && bx == 0 && by == 0 && z == 0){ + // for(int i=0; i < 128; ++i) + // printf("%.2f,", smeminput[i]); + // printf("\n"); + // for(int i=128; i < 256; ++i) + // printf("%.2f,", smeminput[i]); + // printf("\n"); + // } + + // if(tx == 0 && bx == 0 && by == 0 && z == 0){ + // printf("%u, %u, %u, %u \n", innerRowA, innerColA, rowStrideA, weight_sts_addr); + // for(int i=0; i < 16; ++i) + // printf("%f,", smemweight[i]); + // printf("\n"); + // for(int i=0; i < 16; ++i) + // printf("%f,", param.weight[i*param.c*param.r*param.s]); + // printf("\n"); + // } + // lds -#pragma unroll - for (int i = 0; i < 4; ++i){ - weight_frag[0][i] = smemweight[weight_lds_addr + i]; - weight_frag[0][i + 4] = smemweight[weight_lds_addr + i + 16]; - } -#pragma unroll - for (int i = 0; i < 4; ++i){ - input_frag[0][i] = smeminput[input_lds_addr + i]; - input_frag[0][i + 4] = smeminput[input_lds_addr + i + 32]; - } + // int input_lds_addr = (warp_id % 2) * 64 + mma_tid_x * 4; + const uint input_lds_addr = mma_tid_x * WM; + for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) + for (uint i = 0; i < TM; ++i) + input_frag[0][wSubRowIdx * TM + i] = smeminput[input_lds_addr + wSubRowIdx * WSUBM + + threadRowInWarp * TM + i]; - // main loop - for (int crs = 0; crs < param.r * param.s * param.c; crs += 8){ + // int weight_lds_addr = (warp_id / 2) * 32 + mma_tid_y * 4; + const uint weight_lds_addr = mma_tid_y * WN; + for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) + for (uint i = 0; i < TN; ++i) + weight_frag[0][wSubColIdx * TN + i] = smemweight[weight_lds_addr + wSubColIdx * WSUBN + + threadColInWarp * TN + i]; + +// #pragma unroll +// for (int i = 0; i < 4; ++i) +// { +// weight_frag[0][i] = smemweight[weight_lds_addr + i]; +// weight_frag[0][i + 4] = smemweight[weight_lds_addr + i + 16]; +// } + // if(tx == 0 && bx == 0 && by == 0 && z == 0) + // { + // printf("weight_ldg_reg:%f,%f,%f,%f\n", weight_frag[0][0], weight_frag[0][1], weight_frag[0][2], weight_frag[0][3]); + // printf("weight_ldg_reg:%f,%f,%f,%f\n", weight_frag[0][4], weight_frag[0][5], weight_frag[0][6], weight_frag[0][7]); + // } +// #pragma unroll +// for (int i = 0; i < 4; ++i) +// { +// input_frag[0][i] = smeminput[input_lds_addr + i]; +// input_frag[0][i + 4] = smeminput[input_lds_addr + i + 32]; +// } + + + for (int crs = start_k; crs < end_k; crs += BK) + { // ldg - int weiOffsetTmp = crs + 8 + (tx & 7); -#pragma unroll - for (int i = 0; i < 4; ++i){ - if (weiOffsetTmp < weightKOffset && by * 128 + (tx >> 3) * 4 + i < param.k){ - weight_ldg_reg[i] = kernel[weiOffset + weiOffsetTmp + i * weightKOffset]; - } - else{ - weight_ldg_reg[i] = (T)0.f; - } - } - curC = (crs + 8 + warp_id) / (param.r * param.s); // channel offset - curR = ((crs + 8 + warp_id) % (param.r * param.s)) / param.s; // kernel r offset - curS = ((crs + 8 + warp_id) % (param.r * param.s)) % param.s; // kernel s offset +// if (by * BN + tx / 2 < param.k && tx % 2 * 4 < param.c * param.r * param.s){ +// float4 tmp = reinterpret_cast(¶m.weight[by * BN + tx / 2 * weightKOffset + tx % 2 * 4 + crs + 8])[0]; +// weight_ldg_reg[0] = tmp.x; +// weight_ldg_reg[1] = tmp.y; +// weight_ldg_reg[2] = tmp.z; +// weight_ldg_reg[3] = tmp.w; +// } else { +// #pragma unroll +// for (int i = 0; i < 4; ++i) +// weight_ldg_reg[i] = 0.0; +// } + // curR = (crs + 8 + tx % 2 * 4) / (param.s * param.c); // channel offset + // curS = ((crs + 8 + tx % 2 * 4) % (param.s * param.c)) / param.c; // kernel r offset + // curC = ((crs + 8 + tx % 2 * 4) % (param.s * param.c)) % param.c; // kernel s offset +// curR = fastdiv(crs + 8 + (tx % 2) * 4, param.SC_fastdiv); // channel offset +// curS = fastdiv(fastmodulo(crs + 8 + (tx % 2) * 4, param.SC_fastdiv), param.C_fastdiv); // kernel r offset +// curC = fastmodulo(fastmodulo(crs + 8 + (tx % 2) * 4, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + +// int curH = posh_ori + curR; // input h +// int curW = posw_ori + curS; // input w +// if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h){ +// int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; + +// // float4 tmp = reinterpret_cast(¶m.input[inOffset + inOffsetTmp])[0]; +// // input_ldg_reg[0] = tmp.x; +// // input_ldg_reg[1] = tmp.y; +// // input_ldg_reg[2] = tmp.z; +// // input_ldg_reg[3] = tmp.w; +// reinterpret_cast(&input_ldg_reg[0])[0] = reinterpret_cast(¶m.input[inOffset + inOffsetTmp])[0]; } else { +// #pragma unroll +// for (int i = 0; i < 4; ++i) +// input_ldg_reg[i] = 0.0; +// } -#pragma unroll - for (int i = 0; i < 4; ++i){ - int curH = posh_ori[i] + curR * param.d_h; // input h - int curW = posw_ori[i] + curS * param.d_w; // input w - int inOffsetTmp = curC * inChannelOffset + curH * param.w + curW; - if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curC < param.c){ - input_ldg_reg[i] = input[inOffset + inOffsetTmp]; - } - else{ - input_ldg_reg[i] = 0.f; - } - } int load_flag = write_flag ^ 1; #pragma unroll - for (int subcrs = 0; subcrs < 8 - 1; ++subcrs){ + for (int subcrs = 0; subcrs < BK - 1; ++subcrs) + { +// #pragma unroll +// for (int i = 0; i < 4; ++i) +// { +// weight_frag[(subcrs + 1) % 2][i] = smemweight[load_flag * (BN+4) * 8 + weight_lds_addr + (subcrs + 1) * (BN+4) + i]; +// weight_frag[(subcrs + 1) % 2][i + 4] = smemweight[load_flag * (BN+4) * 8 + weight_lds_addr + (subcrs + 1) * (BN+4) + i + 16]; +// } #pragma unroll - for (int i = 0; i < 4; ++i){ - weight_frag[(subcrs + 1) & 1][i] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i]; - weight_frag[(subcrs + 1) & 1][i + 4] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i + 16]; - } - // // compute base pointer once - // T* base_ptr = smemweight + load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132; + for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) +#pragma unroll + for (uint i = 0; i < TN; ++i) + weight_frag[(subcrs + 1) % 2][wSubColIdx * TN + i] = smemweight[load_flag * (BN+PAD) * BK + + (subcrs + 1) * (BN+PAD) + weight_lds_addr + wSubColIdx * WSUBN + threadColInWarp * TN + i]; + // float* base_ptr = smemweight + load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132; // // first 4 values -> weight_frag[...][0..3] // float4 v0 = *reinterpret_cast(base_ptr); @@ -177,92 +343,347 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, // // unpack into weight_frag // *reinterpret_cast(&weight_frag[(subcrs + 1) % 2][0]) = v0; // *reinterpret_cast(&weight_frag[(subcrs + 1) % 2][4]) = v1; -#pragma unroll - for (int i = 0; i < 4; ++i){ - input_frag[(subcrs + 1) & 1][i] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i]; - input_frag[(subcrs + 1) & 1][i + 4] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i + 32]; - } - // #pragma unroll -// for (int i = 0; i < 8; ++i){ -// auto weight_frag_i = ggml_cuda_cast(weight_frag[subcrs % 2][i]); -// #pragma unroll -// for (int j = 0; j < 8; ++j){ -// output_frag[i][j] += weight_frag_i * input_frag[subcrs % 2][j]; -// } +// for (int i = 0; i < 4; ++i) +// { +// input_frag[(subcrs + 1) % 2][i] = smeminput[load_flag * BM * 8 + input_lds_addr + (subcrs + 1) * BM + i]; +// input_frag[(subcrs + 1) % 2][i + 4] = smeminput[load_flag * BM * 8 + input_lds_addr + (subcrs + 1) * BM + i + 32]; // } #pragma unroll - for (int j = 0; j < 8; ++j){ - // auto weight_frag_i = ggml_cuda_cast(weight_frag[subcrs % 2][i]); + for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) #pragma unroll - for (int i = 0; i < 8; ++i){ - output_frag[j][i] += ggml_cuda_cast(weight_frag[subcrs % 2][i]) * input_frag[subcrs % 2][j]; + for (uint i = 0; i < TM; ++i) + input_frag[(subcrs + 1) % 2][wSubRowIdx * TM + i] = smeminput[load_flag * BM * BK + + (subcrs + 1) * BM + input_lds_addr + wSubRowIdx * WSUBM + threadRowInWarp * TM + i]; + +// #pragma unroll +// for (int i = 0; i < 8; ++i) +// { +// #pragma unroll +// for (int j = 0; j < 8; ++j) +// { +// output_frag[i][j] += weight_frag[subcrs % 2][i] * input_frag[subcrs % 2][j]; +// } +// } + // execute warptile matmul +#pragma unroll + for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) { +#pragma unroll + for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) { + // calculate per-thread results +#pragma unroll + for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) { +#pragma unroll + for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) { + output_frag[(wSubRowIdx * TM + resIdxM) * (WNITER * TN) + + (wSubColIdx * TN) + resIdxN] += + input_frag[subcrs % 2][wSubRowIdx * TM + resIdxM] * + ggml_cuda_cast(weight_frag[subcrs % 2][wSubColIdx * TN + resIdxN]); + // if(tx == 0 && bx == 0 && by == 0 && z == 0){ + // printf("subcrs:%d, i:%d, j:%d, %f * %f = %f, acc = %f\n", subcrs, wSubRowIdx * TM + resIdxM, wSubColIdx * TN + resIdxN, + // input_frag[subcrs % 2][wSubRowIdx * TM + resIdxM], + // weight_frag[subcrs % 2][wSubColIdx * TN + resIdxN], + // input_frag[subcrs % 2][wSubRowIdx * TM + resIdxM] * + // weight_frag[subcrs % 2][wSubColIdx * TN + resIdxN], + // output_frag[(wSubRowIdx * TM + resIdxM) * (WNITER * TN) + + // (wSubColIdx * TN) + resIdxN]); + // } + } + } + } + } + } + // ldg +#pragma unroll + for (uint offset = 0; offset + rowStrideA <= BN; offset += rowStrideA) { + if(vec_load){ + if (by * BN + innerRowA + offset < param.k && innerColA * 4 + crs + BK < end_k){ + if constexpr (std::is_same_v(T, float)){ + float4 tmp = reinterpret_cast(&kernel[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK])[0]; + smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 0] = tmp.x; + smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + (BN+PAD)] = tmp.y; + smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 2*(BN+PAD)] = tmp.z; + smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 3*(BN+PAD)] = tmp.w; + } else { + float2 tmp = reinterpret_cast(&kernel[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK])[0]; + half *val = reinterpret_cast(&tmp); + smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 0] = val[0]; + smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + (BN+PAD)] = val[1]; + smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 2*(BN+PAD)] = val[2]; + smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 3*(BN+PAD)] = val[3]; + } + } else { + #pragma unroll + for (int i = 0; i < 4; ++i) + smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f; + } + }else{ + #pragma unroll + for (int i = 0; i < 4; ++i){ + if (by * BN + innerRowA + offset < param.k && innerColA * 4 + crs + BK + i < end_k){ + // float4 tmp = reinterpret_cast(¶m.weight[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK + i])[0]; + smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + i*(BN+PAD)] = kernel[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK + i]; + } else { + smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f; + } + } + } + } +#pragma unroll + for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) { + int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQ : z; + const unsigned int npq_res = (bx * BM + innerRowA + offset) % PQ; + const int posh_ori = fastdiv((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.u - param.p; + const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q; + int inOffset = n * param.c * param.h * param.w ; + if(vec_load){ + const uint curR = fastdiv(innerColA * 4 + crs + BK, param.SC_fastdiv); // channel offset + const uint curS = fastdiv(fastmodulo(innerColA * 4 + crs + BK, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const uint curC = fastmodulo(fastmodulo(innerColA * 4 + crs + BK, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + + const int curH = posh_ori + curR; // input h + const int curW = posw_ori + curS; // input w + if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK < end_k){ + int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; + float4 tmp = reinterpret_cast(&input[inOffset + inOffsetTmp])[0]; + smeminput[write_flag * BM * BK + input_sts_addr + offset + 0] = tmp.x; + smeminput[write_flag * BM * BK + input_sts_addr + offset + BM] = tmp.y; + smeminput[write_flag * BM * BK + input_sts_addr + offset + 2*BM] = tmp.z; + smeminput[write_flag * BM * BK + input_sts_addr + offset + 3*BM] = tmp.w; + } else { + #pragma unroll + for (int i = 0; i < 4; ++i) + smeminput[write_flag * BM * BK + input_sts_addr + offset + i*BM] = 0.f; + } + } else { + #pragma unroll + for (int i = 0; i < 4; ++i){ + const uint curR = fastdiv(innerColA * 4 + crs + BK + i, param.SC_fastdiv); // channel offset + const uint curS = fastdiv(fastmodulo(innerColA * 4 + crs + BK + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const uint curC = fastmodulo(fastmodulo(innerColA * 4 + crs + BK + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + + const int curH = posh_ori + curR; // input h + const int curW = posw_ori + curS; // input w + if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK + i < end_k){ + int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; + smeminput[write_flag * BM * BK + input_sts_addr + offset + i*BM] = input[inOffset + inOffsetTmp]; + } else { + smeminput[write_flag * BM * BK + input_sts_addr + offset + i*BM] = 0.f; + } } } } // sts - for (int i = 0; i < 4; ++i){ - smemweight[write_flag * 132 * 8 + weight_sts_addr + i] = weight_ldg_reg[i]; - } - for (int i = 0; i < 4; ++i){ - smeminput[write_flag * 128 * 8 + input_sts_addr + i * 32] = input_ldg_reg[i]; - } + // for (int i = 0; i < 4; ++i) + // { + // smemweight[write_flag * (BN+4) * 8 + weight_sts_addr + i * (BN+4)] = weight_ldg_reg[i]; + // } + // for (int i = 0; i < 4; ++i) + // { + // smeminput[write_flag * BM * 8 + input_sts_addr + i * BM] = input_ldg_reg[i]; + // } __syncthreads(); write_flag ^= 1; #pragma unroll - for (int i = 0; i < 4; ++i){ - weight_frag[0][i] = smemweight[(load_flag ^ 1) * 132 * 8 + weight_lds_addr + i]; - weight_frag[0][i + 4] = smemweight[(load_flag ^ 1) * 132 * 8 + weight_lds_addr + i + 16]; - } + for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) #pragma unroll - for (int i = 0; i < 4; ++i){ - input_frag[0][i] = smeminput[(load_flag ^ 1) * 128 * 8 + input_lds_addr + i]; - input_frag[0][i + 4] = smeminput[(load_flag ^ 1) * 128 * 8 + input_lds_addr + i + 32]; - } + for (uint i = 0; i < TM; ++i) + input_frag[0][wSubRowIdx * TM + i] = smeminput[(load_flag ^ 1) * BM * BK + + input_lds_addr + wSubRowIdx * WSUBM + threadRowInWarp * TM + i]; #pragma unroll - for (int i = 0; i < 8; ++i){ + for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) #pragma unroll - for (int j = 0; j < 8; ++j){ - output_frag[i][j] += ggml_cuda_cast(weight_frag[1][j]) * input_frag[1][i]; + for (uint i = 0; i < TN; ++i) + weight_frag[0][wSubColIdx * TN + i] = smemweight[(load_flag ^ 1) * (BN+PAD) * BK + + weight_lds_addr + wSubColIdx * WSUBN + threadColInWarp * TN + i]; +// #pragma unroll +// for (int i = 0; i < 4; ++i) +// { +// weight_frag[0][i] = smemweight[(load_flag ^ 1) * (BN+4) * 8 + weight_lds_addr + i]; +// weight_frag[0][i + 4] = smemweight[(load_flag ^ 1) * (BN+4) * 8 + weight_lds_addr + i + 16]; +// } +// #pragma unroll +// for (int i = 0; i < 4; ++i) +// { +// input_frag[0][i] = smeminput[(load_flag ^ 1) * BM * 8 + input_lds_addr + i]; +// input_frag[0][i + 4] = smeminput[(load_flag ^ 1) * BM * 8 + input_lds_addr + i + 32]; +// } +#pragma unroll + for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) { +#pragma unroll + for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) { + // calculate per-thread results +#pragma unroll + for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) { +#pragma unroll + for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) { + output_frag[(wSubRowIdx * TM + resIdxM) * (WNITER * TN) + + (wSubColIdx * TN) + resIdxN] += + input_frag[1][wSubRowIdx * TM + resIdxM] * + ggml_cuda_cast(weight_frag[1][wSubColIdx * TN + resIdxN]); + } + } } } +// #pragma unroll +// for (int i = 0; i < 8; ++i) +// { +// #pragma unroll +// for (int j = 0; j < 8; ++j) +// { +// output_frag[i][j] += weight_frag[1][i] * input_frag[1][j]; +// } +// } } + // if(tx == 59 && bx == 0 && by == 0 && z == 0){ + // for (int i = 0; i < WMITER * TM * WNITER * TN; ++i){ + // printf("%f,", output_frag[i]); + // if((i+1) % (WNITER * TN) == 0) + // printf("\n"); + // } + // printf("\n"); + // } + // if(tx == 59 && bx == 0 && by == 0 && z == 0){ + // int cnt[3] = {0}; + // float values[3] = {-1.f}; + // for (int i = 0; i < WMITER * TM * WNITER * TN; ++i){ + // for(int j = 0; j < 3; j++){ + // if (output_frag[i] == values[j]){ + // cnt[j]++; + // break; + // } else{ + // if (cnt[j] == 0){ + // values[j] = output_frag[i]; + // cnt[j]++; + // break; + // } + // } + // } + // } + // for(int j = 0; j < 3; j++){ + // if(values[j] != -1.f) + // printf("value: %f, cnt: %d \n", values[j], cnt[j]); + // } + // } + // reuse smem float *smemoutput = reinterpret_cast(smem); + // float *smembias = reinterpret_cast(smem + 16 * 1024); + + // bias ldg/sts + // if (tx < BN) + // { + // smembias[tx] = param.bias[by * BN + tx]; + // } + + // constexpr uint OUTMITER = (TM * TN * WNITER * WMITER * NUM_THREADS) / (2 * BK * (BM + BN)) / OUTNITER; + // const uint WMITER_TM_OUTMITER = WMITER * TM / OUTMITER; + // const uint WNITER_TN_OUTNITER = WNITER * TN / OUTNITER; - uint32_t output_sts_addr = warp_id * 512 + mma_tid_y * 4 * 8 * 4 + mma_tid_x * 4; - uint32_t output_lds_addr = warp_id * 512 + lane_id; - uint32_t m_idx = blockIdx.y * 128 + warp_id / 2 * 32; - uint32_t n_idx = blockIdx.x * 128 + warp_id % 2 * 64 + lane_id; +// // uint32_t bias_lds_addr = warp_id / 2 * 32; + +// #pragma unroll +// for (int i = 0; i < 2; ++i) +// { +// #pragma unroll +// for (int j = 0; j < 2; ++j) +// { +// __syncthreads(); + +// #pragma unroll +// for (int subi = 0; subi < 4; ++subi) +// { +// #pragma unroll +// for (int subj = 0; subj < 4; ++subj) +// { +// // output sts +// smemoutput[output_sts_addr + subi * 8 * 4 + subj] = output_frag[i * 4 + subi][j * 4 + subj]; +// } +// } +// __syncthreads(); + +// #pragma unroll +// for (int subk = 0; subk < 16; ++subk) +// { +// int outOffset = z * param.k * param.Oh * param.Ow + (m_idx + i * 16 + subk) * param.Oh * param.Ow + n_idx + j * 32; +// if ((m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow) +// param.output[outOffset] = smemoutput[output_lds_addr + subk * 32]; +// } +// } +// } + const uint output_lds_addr = warp_id * WSUBM * WSUBN + lane_id; + // const uint m_idx = by * BN + mma_tid_y * WN + threadColInWarp * WNITER_TN_OUTNITER; + // const uint n_idx = bx * BM + mma_tid_x * WM + threadRowInWarp * WMITER_TM_OUTMITER; + // const uint output_sts_addr = warp_id * WMITER_TM_OUTMITER * WNITER_TN_OUTNITER * WARPSIZE + + // (threadRowInWarp * (WSUBN / TN) + threadColInWarp) * WMITER_TM_OUTMITER * WNITER_TN_OUTNITER; + const uint output_sts_addr = mma_tid_x * BN / WN * TM * TN * WARPSIZE + mma_tid_y * TM * TN * WARPSIZE + + threadColInWarp * TN * WSUBM + threadRowInWarp * TM; + const uint m_idx = by * BN + mma_tid_y * WN; + const uint n_idx = bx * BM + mma_tid_x * WM; #pragma unroll - for (int i = 0; i < 2; ++i){ + for (int i = 0; i < WMITER; ++i) + { #pragma unroll - for (int j = 0; j < 2; ++j){ + for (int j = 0; j < WNITER; ++j) + { __syncthreads(); + #pragma unroll - for (int subi = 0; subi < 4; ++subi){ + for (int subi = 0; subi < TM; ++subi) + { #pragma unroll - for (int subj = 0; subj < 4; ++subj){ + for (int subj = 0; subj < TN; ++subj) + { // output sts - smemoutput[output_sts_addr + subj * 8 * 4 + subi] = output_frag[i * 4 + subi][j * 4 + subj]; + smemoutput[output_sts_addr + subj * WSUBM + subi] = + output_frag[(i * TM + subi) * (WNITER * TN) + j * TN + subj]; } } __syncthreads(); - #pragma unroll - for (int subk = 0; subk < 16; ++subk){ - int outOffset = z * param.k * param.Oh * param.Ow + (m_idx + j * 16 + subk) * param.Oh * param.Ow + n_idx + i * 32; - if ((m_idx + j * 16 + subk) < param.k && (n_idx + i * 32) < param.Oh * param.Ow) - output[outOffset] = smemoutput[output_lds_addr + subk * 32]; + for (int subk = 0; subk < TM * TN; ++subk){ + const uint row = m_idx + j * WSUBN + (lane_id + subk * WARPSIZE) / WSUBM; + const uint gemm_i = n_idx + i * WSUBM + (lane_id + subk * WARPSIZE) % WSUBM; + const int n = (ksplit > 0) ? gemm_i / PQ : z; + const int col = (ksplit > 0) ? gemm_i % PQ : gemm_i; + if (n < param.n && row < param.k && col < param.Oh * param.Ow){ + // int outOffset = z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow + (m_idx + i * 16 + subk) * param.Oh * param.Ow + (n_idx + j * 32); + // if (n < param.n && (m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow) + // param.interm[outOffset] = smemoutput[output_lds_addr + subk * 32]; + const uint outOffset = ksplit > 0 ? + z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow + + row * param.Oh * param.Ow + col : + z * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col; + output[outOffset] = smemoutput[output_lds_addr + subk * WARPSIZE]; + } } } } } +#define NUM_VARIANTS 6 + +/* + conv_shapes[][0]: ne_input=[384,512,256,1],ne_kernel=[3,3,256,256] + conv_shapes[][1]: ne_input=[96,128,512,1],ne_kernel=[3,3,512,512] + conv_shapes[][2]: ne_input=[192,256,512,1git diff],ne_kernel=[3,3,512,512] +*/ +constexpr static int conv_shapes[][NUM_VARIANTS] = { + { 128, 128, 128 }, // BM + { 256, 128, 256 }, // BN + { 8, 8, 8 }, // BK + { 128, 64, 32 }, // WM + { 32, 32 , 256 }, // WN + { 2, 2, 1 }, // WNITER + { 8, 4, 8 }, // TM + { 8, 4, 4 }, // TN + { 256, 256, 128} // NUM_THREADS +}; + template static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const param_t P, cudaStream_t st) { int blockx = ((P.Oh * P.Ow + 127) / 128); // blockx number @@ -324,6 +745,9 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * const int64_t total = B * OC * OH * OW; param_t params = { B, IC, IH, IW, OC, KH, KW, ST_Y, ST_X, PD_Y, PD_X, DL_Y, DL_X, OH, OW }; + params.SC_fastdiv = init_fastdiv_values(KW*KH); + params.OW_fastdiv = init_fastdiv_values(OW); + params.C_fastdiv = init_fastdiv_values(IC); if (kernel->type == GGML_TYPE_F16) { conv2d_implicit_cuda_f16(X_D, (half *) K_D, Y_D, params, st); diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index 46161feb3c..4fe6134873 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -1,5 +1,27 @@ #pragma once #include "common.cuh" +typedef struct{ + unsigned int n; //batch size + unsigned int c; //number if channels + unsigned int h; //height + unsigned int w; //width + unsigned int k; //number of filters + unsigned int r; //filter height + unsigned int s; //filter width + unsigned int u; //stride height + unsigned int v; //stride width + unsigned int p; //padding height + unsigned int q; //padding width + unsigned int d_h; //dilation height + unsigned int d_w; //dilation width + unsigned int Oh; //output height + unsigned int Ow; //output width + uint3 SC_fastdiv; + uint3 OW_fastdiv; + uint3 C_fastdiv; +} param_t; + + #define CUDA_CONV2D_IMPLICT_BLOCK_SIZE 256 void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst);