diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 9b2331876b..1e6eee88c9 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -27,6 +27,7 @@ static __global__ void reduce_f32(const src_T * __restrict__ x, dst_T * __restri } } + template static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, const int ne00, const int ne01){ @@ -63,6 +64,8 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co } } + + template){ - float4 tmp = reinterpret_cast(&kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4])[0]; - smemweight[weight_sts_addr + offset + 0] = tmp.x; - smemweight[weight_sts_addr + offset + (BN+PAD)] = tmp.y; - smemweight[weight_sts_addr + offset + 2*(BN+PAD)] = tmp.z; - smemweight[weight_sts_addr + offset + 3*(BN+PAD)] = tmp.w; - }else{ // read 4 halves - float2 tmp = reinterpret_cast(&kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4])[0]; - const half *val = reinterpret_cast(&tmp); - smemweight[weight_sts_addr + offset + 0] = val[0]; - smemweight[weight_sts_addr + offset + (BN+PAD)] = val[1]; - smemweight[weight_sts_addr + offset + 2*(BN+PAD)] = val[2]; - smemweight[weight_sts_addr + offset + 3*(BN+PAD)] = val[3]; - } - } else { -#pragma unroll - for (int i = 0; i < 4; ++i){ - smemweight[weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f; - } - } - }else{ -#pragma unroll - for (int i = 0; i < 4; ++i){ - if (by * BN + innerRowA + offset < param.k && start_k + innerColA * 4 + i < end_k){ - smemweight[weight_sts_addr + offset + i*(BN+PAD)] = kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4 + i]; - } else { - smemweight[weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f; - } - } - } - } + loadFilter + (kernel, smemweight, by, innerRowA, innerColA, weightKOffset, + start_k, end_k, param); + loadInput + (input, smeminput, bx, innerRowA, 
innerColA, + start_k, end_k, PQ, CHW, inChannelOffset, param); - const uint input_sts_addr = innerRowA + innerColA * (BM+PAD) * 4; -#pragma unroll - for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) { - int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQ : z; - const unsigned int npq_res = (bx * BM + innerRowA + offset) % PQ; - const int posh_ori = fastdiv((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.u - param.p; - const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q; - int inOffset = n * param.c * param.h * param.w ; - if(vec_load){ - const uint cur0 = fastdiv(start_k + innerColA * 4, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset - const uint cur1 = fastdiv(fastmodulo(start_k + innerColA * 4, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), - layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - const uint cur2 = fastmodulo(fastmodulo(start_k + innerColA * 4, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), - layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - const uint curC = layout == 0 ? cur2 : cur0; - const uint curR = layout == 0 ? cur0 : cur1; - const uint curS = layout == 0 ? cur1 : cur2; - const int curH = posh_ori + curR * param.d_h; // input h - const int curW = posw_ori + curS * param.d_w; // input w - if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 < end_k){ - int inOffsetTmp = layout == 0 ? 
- curH * inChannelOffset + curW * param.c + curC: - curC * inChannelOffset + curH * param.w + curW; - float4 tmp = reinterpret_cast(&input[inOffset + inOffsetTmp])[0]; - smeminput[input_sts_addr + offset + 0] = tmp.x; - smeminput[input_sts_addr + offset + BM+PAD] = tmp.y; - smeminput[input_sts_addr + offset + 2*(BM+PAD)] = tmp.z; - smeminput[input_sts_addr + offset + 3*(BM+PAD)] = tmp.w; - } else { -#pragma unroll - for (int i = 0; i < 4; ++i) - smeminput[input_sts_addr + offset + i*(BM+PAD)] = 0.f; - } - } else { -#pragma unroll - for (int i = 0; i < 4; ++i){ - const uint cur0 = fastdiv(start_k + innerColA * 4 + i, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset - const uint cur1 = fastdiv(fastmodulo(start_k + innerColA * 4 + i, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), - layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - const uint cur2 = fastmodulo(fastmodulo(start_k + innerColA * 4 + i, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), - layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - const uint curC = layout == 0 ? cur2 : cur0; - const uint curR = layout == 0 ? cur0 : cur1; - const uint curS = layout == 0 ? cur1 : cur2; - const int curH = posh_ori + curR * param.d_h; // input h - const int curW = posw_ori + curS * param.d_w; // input w - if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 + i < end_k){ - int inOffsetTmp = layout == 0 ? 
- curH * inChannelOffset + curW * param.c + curC: - curC * inChannelOffset + curH * param.w + curW; - smeminput[input_sts_addr + offset + i*(BM+PAD)] = input[inOffset + inOffsetTmp]; - } else { - smeminput[input_sts_addr + offset + i*(BM+PAD)] = 0.f; - } - } - } - } __syncthreads(); // lds @@ -279,106 +190,15 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, } } // ldg -#pragma unroll - for (uint offset = 0; offset + rowStrideA <= BN; offset += rowStrideA) { - if(vec_load){ - if (by * BN + innerRowA + offset < param.k && innerColA * 4 + crs + BK < end_k){ - if constexpr (std::is_same_v){ - float4 tmp = reinterpret_cast(&kernel[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK])[0]; - smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 0] = tmp.x; - smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + (BN+PAD)] = tmp.y; - smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 2*(BN+PAD)] = tmp.z; - smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 3*(BN+PAD)] = tmp.w; - } else { - float2 tmp = reinterpret_cast(&kernel[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK])[0]; - const half *val = reinterpret_cast(&tmp); - smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 0] = val[0]; - smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + (BN+PAD)] = val[1]; - smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 2*(BN+PAD)] = val[2]; - smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 3*(BN+PAD)] = val[3]; - } - } else { -#pragma unroll - for (int i = 0; i < 4; ++i) - smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f; - } - }else{ -#pragma unroll - for (int i = 0; i < 4; ++i){ - if (by * BN + innerRowA + offset < param.k && innerColA * 4 + crs + BK + i < end_k){ - // float4 tmp = reinterpret_cast(¶m.weight[(by * BN + innerRowA + offset) 
* weightKOffset + innerColA * 4 + crs + BK + i])[0]; - smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + i*(BN+PAD)] = kernel[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK + i]; - } else { - smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f; - } - } - } - } -#pragma unroll - for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) { - int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQ : z; - const unsigned int npq_res = (bx * BM + innerRowA + offset) % PQ; - const int posh_ori = fastdiv((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.u - param.p; - const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q; - int inOffset = n * param.c * param.h * param.w ; - if(vec_load){ - const uint cur0 = fastdiv(innerColA * 4 + crs + BK, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset - const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), - layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), - layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - const uint curC = layout == 0 ? cur2 : cur0; - const uint curR = layout == 0 ? cur0 : cur1; - const uint curS = layout == 0 ? cur1 : cur2; - const int curH = posh_ori + curR * param.d_h; // input h - const int curW = posw_ori + curS * param.d_w; // input w - if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK < end_k){ - // int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; - int inOffsetTmp = layout == 0 ? 
- curH * inChannelOffset + curW * param.c + curC: - curC * inChannelOffset + curH * param.w + curW; - float4 tmp = reinterpret_cast(&input[inOffset + inOffsetTmp])[0]; - smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + 0] = tmp.x; - smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + BM+PAD] = tmp.y; - smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + 2*(BM+PAD)] = tmp.z; - smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + 3*(BM+PAD)] = tmp.w; - } else { -#pragma unroll - for (int i = 0; i < 4; ++i) - smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + i*(BM+PAD)] = 0.f; - } - } else { -#pragma unroll - for (int i = 0; i < 4; ++i){ - const uint cur0 = fastdiv(innerColA * 4 + crs + BK + i, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset - const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK + i, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), - layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK + i, - layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), - layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset - const uint curC = layout == 0 ? cur2 : cur0; - const uint curR = layout == 0 ? cur0 : cur1; - const uint curS = layout == 0 ? cur1 : cur2; + loadFilter + (kernel, &smemweight[write_flag * (BN+PAD) * BK], by, innerRowA, innerColA, weightKOffset, + crs+BK, end_k, param); + + loadInput + (input, &smeminput[write_flag * (BM+PAD) * BK], bx, innerRowA, innerColA, + crs + BK, end_k, PQ, CHW, inChannelOffset, param); - const int curH = posh_ori + curR * param.d_h; // input h - const int curW = posw_ori + curS * param.d_w; // input w - if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK + i < end_k){ - int inOffsetTmp = layout == 0 ? 
-                        curH * inChannelOffset + curW * param.c + curC:
-                        curC * inChannelOffset + curH * param.w + curW;
-                    smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + i*(BM+PAD)] = input[inOffset + inOffsetTmp];
-                } else {
-                    smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + i*(BM+PAD)] = 0.f;
-                }
-            }
-        }
-    }
         __syncthreads();
         write_flag ^= 1;
diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh
index a49210ddd8..85936e42c6 100644
--- a/ggml/src/ggml-cuda/conv2d-implicit.cuh
+++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh
@@ -359,6 +359,139 @@ __device__ __forceinline__ uint32_t cvta_to_shared_u32(const void *pointer) {
     return address;
 }
 
+template // NOTE(review): the template parameter list is missing (looks stripped in extraction); the body relies on T, BN, PAD, rowStrideA and vec_load, presumably declared there - restore before use
+__device__ __forceinline__ void loadFilter(const T * __restrict__ kernel,      // source weights; element (row, k) is read from kernel[row * weightKOffset + k]
+                                            T * __restrict__ smemweight,        // destination tile; slot (k_local, row_local) is written at [k_local * (BN+PAD) + row_local] (name suggests shared memory - confirm at call site)
+                                            const unsigned int by,              // tile index along the filter-row dimension; rows of this tile start at by * BN
+                                            const unsigned int innerRowA,       // row handled by this thread within the tile
+                                            const unsigned int innerColA,       // selects this thread's group of 4 consecutive k indices
+                                            const unsigned int weightKOffset,   // distance in elements between consecutive rows of kernel
+                                            const unsigned int start_k,         // first k index covered by this call
+                                            const unsigned int end_k,           // exclusive upper bound on k; entries at or beyond it are written as zero
+                                            const param_t param){               // only param.k (bound on the row index) is read here
+    // Copies this thread's share of a weight tile from kernel into smemweight: 4 consecutive k values per row, one row per loop step; rows >= param.k and k >= end_k are zero-filled.
+    const unsigned int weight_sts_addr = innerRowA + innerColA * (BN+PAD) * 4; // destination base: row innerRowA, k-slot innerColA * 4 (k-slots are BN+PAD apart)
+    const unsigned int kidx = start_k + innerColA * 4; // first of the 4 k indices handled by this thread
+#pragma unroll
+    for (int offset = 0; offset + rowStrideA <= BN; offset += rowStrideA) { // visits rows innerRowA + offset in steps of rowStrideA
+        const unsigned int nidx = by * BN + innerRowA + offset; // global filter-row index
+        if (vec_load) { // one wide load fetches all 4 k values
+            if (nidx < param.k && kidx < end_k) { // NOTE(review): only kidx is bounds-checked, kidx+1..kidx+3 are not; presumably vec_load is enabled only when the 4-element group is fully in range and suitably aligned - confirm
+                if constexpr (std::is_same_v){ // NOTE(review): type arguments appear stripped; presumably tests whether T is float, since this branch reads a float4
+                    float4 tmp = reinterpret_cast(&kernel[nidx * weightKOffset + kidx])[0]; // NOTE(review): cast target type appears stripped
+                    smemweight[weight_sts_addr + offset + 0] = tmp.x;
+                    smemweight[weight_sts_addr + offset + (BN+PAD)] = tmp.y;
+                    smemweight[weight_sts_addr + offset + 2*(BN+PAD)] = tmp.z;
+                    smemweight[weight_sts_addr + offset + 3*(BN+PAD)] = tmp.w;
+                } else { // read 4 halves
+                    float2 tmp = reinterpret_cast(&kernel[nidx * weightKOffset + kidx])[0]; // NOTE(review): cast target type appears stripped; the 8 bytes are reinterpreted as 4 half values below
+                    const half *val = reinterpret_cast(&tmp);
+                    smemweight[weight_sts_addr + offset + 0] = val[0];
+                    smemweight[weight_sts_addr + offset + (BN+PAD)] = val[1];
+                    smemweight[weight_sts_addr + offset + 2*(BN+PAD)] = val[2];
+                    smemweight[weight_sts_addr + offset + 3*(BN+PAD)] = val[3];
+                }
+            } else { // row or k out of range: zero-fill all 4 slots
+#pragma unroll
+                for (int i = 0; i < 4; ++i) {
+                    smemweight[weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f;
+                }
+            }
+        } else { // element-wise path: each of the 4 k indices is bounds-checked separately
+#pragma unroll
+            for (int i = 0; i < 4; ++i) {
+                if (nidx < param.k && kidx + i < end_k) {
+                    smemweight[weight_sts_addr + offset + i*(BN+PAD)] = kernel[nidx * weightKOffset + kidx + i];
+                } else {
+                    smemweight[weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f;
+                }
+            }
+        }
+    }
+}
+// loadInput: for each output position of this thread (one per loop step) and its 4 consecutive k indices, decodes k into (channel, kernel row, kernel column), derives the input coordinate, and stores input[...] into smeminput; coordinates outside [0, param.h) x [0, param.w) and k >= end_k are stored as 0.
+// With layout == 0 the input offset is curH * inChannelOffset + curW * param.c + curC; otherwise it is curC * inChannelOffset + curH * param.w + curW. fastdiv/fastmodulo are defined elsewhere (presumably quotient/remainder by the divisor encoded in their second argument).
+template // NOTE(review): the template parameter list is missing (looks stripped in extraction); the body relies on BM, PAD, rowStrideA, layout, vec_load and ksplit, presumably declared there - restore before use
+__device__ __forceinline__ void loadInput(const float * __restrict__ input,    // source activations; read at input[n * CHW + per-sample offset]
+                                           float * __restrict__ smeminput,      // destination tile; slot (k_local, row_local) is written at [k_local * (BM+PAD) + row_local] (name suggests shared memory - confirm at call site)
+                                           const unsigned int bx,               // tile index along the output-position dimension; positions of this tile start at bx * BM
+                                           const unsigned int innerRowA,        // output position handled by this thread within the tile
+                                           const unsigned int innerColA,        // selects this thread's group of 4 consecutive k indices
+                                           const unsigned int start_k,          // first k index covered by this call
+                                           const unsigned int end_k,            // exclusive upper bound on k; entries at or beyond it are written as zero
+                                           const unsigned int PQ,               // when ksplit > 0, midx / PQ is the batch index and midx % PQ the position within the sample
+                                           const unsigned int CHW,              // element count per sample; n * CHW is the start of sample n
+                                           const unsigned int inChannelOffset,  // multiplies curH when layout == 0 and curC otherwise
+                                           const param_t param) {               // reads OW_fastdiv, SC/RS/C/S_fastdiv, u, v, p, q, d_h, d_w, h, w, c
+    const unsigned int input_sts_addr = innerRowA + innerColA * (BM+PAD) * 4; // destination base: row innerRowA, k-slot innerColA * 4 (k-slots are BM+PAD apart)
+    const unsigned int kidx = start_k + innerColA * 4; // first of the 4 k indices handled by this thread
+#pragma unroll
+    for (unsigned int offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) { // visits rows innerRowA + offset in steps of rowStrideA
+        const unsigned int midx = bx * BM + innerRowA + offset; // flattened output-position index
+        int n = (ksplit > 0) ? midx / PQ : blockIdx.z; // batch index: derived from midx when ksplit > 0, otherwise taken from blockIdx.z
+        const unsigned int npq_res = midx % PQ; // position within the sample; only used when ksplit > 0
+        const int posh_ori = fastdiv((ksplit > 0) ? npq_res: midx, param.OW_fastdiv) * param.u - param.p; // base input row; curR * param.d_h is added below to get curH (may be negative)
+        const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: midx, param.OW_fastdiv) * param.v - param.q; // base input column; curS * param.d_w is added below to get curW (may be negative)
+        const unsigned int inOffset = n * CHW; // start of sample n in input
+        if (vec_load) { // k is decoded once for kidx and a single float4 load supplies all 4 values
+            const unsigned int cur0 = fastdiv(kidx,
+                                        layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // first decoded component: curR when layout == 0, curC otherwise
+            const unsigned int cur1 = fastdiv(fastmodulo(kidx,
+                                        layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
+                                        layout == 0 ? param.C_fastdiv : param.S_fastdiv); // second decoded component: curS when layout == 0, curR otherwise
+            const unsigned int cur2 = fastmodulo(fastmodulo(kidx,
+                                        layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
+                                        layout == 0 ? param.C_fastdiv : param.S_fastdiv); // third decoded component: curC when layout == 0, curS otherwise
+            const unsigned int curC = layout == 0 ? cur2 : cur0;
+            const unsigned int curR = layout == 0 ? cur0 : cur1;
+            const unsigned int curS = layout == 0 ? cur1 : cur2;
+            const int curH = posh_ori + curR * param.d_h; // input h
+            const int curW = posw_ori + curS * param.d_w; // input w
+            if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && kidx < end_k) { // NOTE(review): checks only the first of the 4 elements; presumably vec_load is enabled only when kidx..kidx+3 are contiguous in input, in range and 16-byte aligned - confirm
+                int inOffsetTmp = layout == 0 ?
+                            curH * inChannelOffset + curW * param.c + curC:
+                            curC * inChannelOffset + curH * param.w + curW;
+                float4 tmp = reinterpret_cast(&input[inOffset + inOffsetTmp])[0]; // NOTE(review): cast target type appears stripped
+                smeminput[input_sts_addr + offset + 0] = tmp.x;
+                smeminput[input_sts_addr + offset + BM+PAD] = tmp.y;
+                smeminput[input_sts_addr + offset + 2*(BM+PAD)] = tmp.z;
+                smeminput[input_sts_addr + offset + 3*(BM+PAD)] = tmp.w;
+            } else { // coordinate outside the input or k out of range: zero-fill all 4 slots
+#pragma unroll
+                for (int i = 0; i < 4; ++i)
+                    smeminput[input_sts_addr + offset + i*(BM+PAD)] = 0.f;
+            }
+        } else { // element-wise path: k is decoded and bounds-checked separately for each of the 4 indices
+#pragma unroll
+            for (int i = 0; i < 4; ++i) {
+                const unsigned int cur0 = fastdiv(kidx + i,
+                                            layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // first decoded component: curR when layout == 0, curC otherwise
+                const unsigned int cur1 = fastdiv(fastmodulo(kidx + i,
+                                            layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
+                                            layout == 0 ? param.C_fastdiv : param.S_fastdiv); // second decoded component: curS when layout == 0, curR otherwise
+                const unsigned int cur2 = fastmodulo(fastmodulo(kidx + i,
+                                            layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
+                                            layout == 0 ? param.C_fastdiv : param.S_fastdiv); // third decoded component: curC when layout == 0, curS otherwise
+                const unsigned int curC = layout == 0 ? cur2 : cur0;
+                const unsigned int curR = layout == 0 ? cur0 : cur1;
+                const unsigned int curS = layout == 0 ? cur1 : cur2;
+                const int curH = posh_ori + curR * param.d_h; // input h
+                const int curW = posw_ori + curS * param.d_w; // input w
+                if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && kidx + i < end_k) {
+                    int inOffsetTmp = layout == 0 ?
+                                curH * inChannelOffset + curW * param.c + curC:
+                                curC * inChannelOffset + curH * param.w + curW;
+                    smeminput[input_sts_addr + offset + i*(BM+PAD)] = input[inOffset + inOffsetTmp];
+                } else { // coordinate outside the input or k out of range
+                    smeminput[input_sts_addr + offset + i*(BM+PAD)] = 0.f;
+                }
+            }
+        }
+    }
+}
+
 #define CUDA_CONV2D_IMPLICT_BLOCK_SIZE 256
 
 void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst);