unroll some loops

This commit is contained in:
bssrdf 2025-10-15 12:46:46 -04:00
parent b70cca2ea3
commit 3f99818925
2 changed files with 13 additions and 6 deletions

View File

@ -4,8 +4,9 @@
#include "convert.cuh"
#include "conv2d-implicit.cuh"
static const int WARPSIZE = 32; // warpSize is not constexpr
typedef unsigned int uint;
constexpr uint WARPSIZE = 32;
static __global__ void reduce_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols, const int nrows) {
const int row = blockIdx.x;
@ -125,6 +126,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
// ldg
const uint weight_sts_addr = innerRowA + innerColA * (BN+PAD) * 4;
#pragma unroll
for (uint offset = 0; offset + rowStrideA <= BN; offset += rowStrideA) {
if(vec_load){
// if (by * BN + innerRowA + offset < param.k && start_k + innerColA * 4 < param.c * param.r * param.s){
@ -174,6 +176,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
// int curC = ((tx % 2) * 4 % (param.s * param.c)) % param.c; // kernel s offset
const uint input_sts_addr = innerRowA + innerColA * BM * 4;
#pragma unroll
for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) {
int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQ : z;
const unsigned int npq_res = (bx * BM + innerRowA + offset) % PQ;
@ -278,14 +281,18 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
// lds
// int input_lds_addr = (warp_id % 2) * 64 + mma_tid_x * 4;
const uint input_lds_addr = mma_tid_x * WM;
#pragma unroll
for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx)
#pragma unroll
for (uint i = 0; i < TM; ++i)
input_frag[0][wSubRowIdx * TM + i] = smeminput[input_lds_addr + wSubRowIdx * WSUBM +
threadRowInWarp * TM + i];
// int weight_lds_addr = (warp_id / 2) * 32 + mma_tid_y * 4;
const uint weight_lds_addr = mma_tid_y * WN;
#pragma unroll
for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx)
#pragma unroll
for (uint i = 0; i < TN; ++i)
weight_frag[0][wSubColIdx * TN + i] = smemweight[weight_lds_addr + wSubColIdx * WSUBN +
threadColInWarp * TN + i];
@ -495,7 +502,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
smeminput[write_flag * BM * BK + input_sts_addr + offset + 2*BM] = tmp.z;
smeminput[write_flag * BM * BK + input_sts_addr + offset + 3*BM] = tmp.w;
} else {
#pragma unroll
#pragma unroll
for (int i = 0; i < 4; ++i)
smeminput[write_flag * BM * BK + input_sts_addr + offset + i*BM] = 0.f;
}
@ -781,11 +788,11 @@ static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D,
}
static void conv2d_implicit_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const param_t P, cudaStream_t st) {
conv2d_implicit_cuda<half, 3>(X_D, K_D, Y_D, P, st);
conv2d_implicit_cuda<half, 1>(X_D, K_D, Y_D, P, st);
}
static void conv2d_implicit_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const param_t P, cudaStream_t st) {
conv2d_implicit_cuda<float, 3>(X_D, K_D, Y_D, P, st);
conv2d_implicit_cuda<float, 1>(X_D, K_D, Y_D, P, st);
}
void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@ -811,7 +818,7 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor *
const int LT = p[6]; // layout
GGML_ASSERT(LT == 0 || LT == 1);
// same number of input channels
GGML_ASSERT(LT == 0 ? input->ne[0] == kernel->ne[0] : input->ne[2] == kernel->ne[2]);
// No cwhn

View File

@ -262,7 +262,7 @@ struct ggml_cgraph * build_graph_2(const test_model& model) {
// printf("conv2d: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]);
struct ggml_tensor* wino_res = ggml_conv_2d_implicitgemm(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1, 0);
struct ggml_tensor* wino_res = ggml_conv_2d_implicitgemm(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1, 1);
// struct ggml_tensor* wino_res = ggml_conv_2d_direct(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1);
ggml_set_name(wino_res, "wino_res");
ggml_build_forward_expand(gf, wino_res);