From 66f6d16265cc76984efe8a4b4d739c6ca2a0e0e0 Mon Sep 17 00:00:00 2001
From: bssrdf
Date: Thu, 23 Oct 2025 13:52:26 -0400
Subject: [PATCH] WIP

---
 ggml/src/ggml-cuda/conv2d-implicit.cu  | 127 ++++++++------
 ggml/src/ggml-cuda/conv2d-implicit.cuh | 232 ++++++++++++++++++++++++-
 2 files changed, 301 insertions(+), 58 deletions(-)

diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu
index 1866517775..a11d306c6c 100644
--- a/ggml/src/ggml-cuda/conv2d-implicit.cu
+++ b/ggml/src/ggml-cuda/conv2d-implicit.cu
@@ -892,27 +892,43 @@ __device__ __forceinline__ void ldmatrix_b(
     static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");
     static_assert(mma_tiles_per_warp_n == 8, "mma_tiles_per_warp_n must be 8");
 
-    uint32_t (&reg_) [4][8] = reinterpret_cast<uint32_t(&)[4][8]>(reg);
-    const unsigned int logical_offset = ((threadIdx.x % 8) * smem_stride) + (((threadIdx.x % 32) / 8) * 8);
-    unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b11100000000) >> 5);
+//    uint32_t (&reg_) [4][8] = reinterpret_cast<uint32_t(&)[4][8]>(reg);
+//    const unsigned int logical_offset = ((threadIdx.x % 8) * smem_stride) + (((threadIdx.x % 32) / 8) * 8);
+//    unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b11100000000) >> 5);
+//    uint32_t src_addr = cvta_to_shared_u32(src + swizzled_offset);
+//    constexpr unsigned int smem_stride_ = smem_stride * sizeof(half); // convert stride to bytes
+    uint32_t (&reg_) [4][8] = reinterpret_cast<uint32_t(&)[4][8]>(reg);
+    unsigned int logical_offset = (threadIdx.x % 32) * smem_stride;
+    unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b10000000) >> 4);
+    swizzled_offset = swizzled_offset ^ ((swizzled_offset & 0b1100000) >> 2);
     uint32_t src_addr = cvta_to_shared_u32(src + swizzled_offset);
     constexpr unsigned int smem_stride_ = smem_stride * sizeof(half); // convert stride to bytes
+
+//    asm volatile (
+//        "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
+//        "{%0, %1, %2, %3}, [%4];"
+//        : "=r"(reg_[0][0]), "=r"(reg_[0][1]), "=r"(reg_[0][2]), "=r"(reg_[0][3])
+//        : "r"(src_addr)
+//    );
+
+    // 0
     asm volatile (
-        "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
-        "{%0, %1, %2, %3}, [%4];"
-        : "=r"(reg_[0][0]), "=r"(reg_[0][1]), "=r"(reg_[0][2]), "=r"(reg_[0][3])
-        : "r"(src_addr)
-    );
+        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+        "{%0, %1, %2, %3}, [%4];"
+        : "=r"(reg_[0][0]), "=r"(reg_[0][1]), "=r"(reg_[0][2]), "=r"(reg_[0][3])
+        : "r"(src_addr)
+    );
+
     asm volatile (
         "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
         "{%0, %1, %2, %3}, [%4];"
         : "=r"(reg_[0][4]), "=r"(reg_[0][5]), "=r"(reg_[0][6]), "=r"(reg_[0][7])
-        : "r"(src_addr ^ 0b1000000)
+        // : "r"(src_addr ^ 0b1000000)
+        : "r"(src_addr + 32 * smem_stride_)
     );
 
-    src_addr += 8 * smem_stride_;
+    src_addr ^= 0b10000;
 
     asm volatile (
         "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
@@ -925,10 +941,12 @@ __device__ __forceinline__ void ldmatrix_b(
         "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
         "{%0, %1, %2, %3}, [%4];"
         : "=r"(reg_[1][4]), "=r"(reg_[1][5]), "=r"(reg_[1][6]), "=r"(reg_[1][7])
-        : "r"(src_addr ^ 0b1000000)
+        // : "r"(src_addr ^ 0b1000000)
+        : "r"(src_addr + 32 * smem_stride_)
     );
 
-    src_addr += 8 * smem_stride_;
+//    src_addr += 8 * smem_stride_;
+    src_addr ^= 0b110000;
 
     asm volatile (
         "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
@@ -941,10 +959,11 @@ __device__ __forceinline__ void ldmatrix_b(
         "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
         "{%0, %1, %2, %3}, [%4];"
        : "=r"(reg_[2][4]), "=r"(reg_[2][5]), "=r"(reg_[2][6]), "=r"(reg_[2][7])
-        : "r"(src_addr ^ 0b1000000)
+        // : "r"(src_addr ^ 0b1000000)
+        : "r"(src_addr + 32 * smem_stride_)
     );
 
-    src_addr += 8 * smem_stride_;
+    src_addr ^= 0b10000;
 
     asm volatile (
         "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
@@ -957,7 +976,8 @@ __device__ __forceinline__ void ldmatrix_b(
         "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
         "{%0, %1, %2, %3}, [%4];"
        : "=r"(reg_[3][4]), "=r"(reg_[3][5]), "=r"(reg_[3][6]), "=r"(reg_[3][7])
-        : "r"(src_addr ^ 0b1000000)
+        // : "r"(src_addr ^ 0b1000000)
+        : "r"(src_addr + 32 * smem_stride_)
     );
 }
 
@@ -1038,7 +1058,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
 
     // prefetch the first block tile of A,B into shared memory
     // half* A_block_gmem = input + (block_m * BM * A_stride);
     half* A_block_gmem = input;
-    half* B_block_gmem = weight + (block_n * weightKOffset);
+    half* B_block_gmem = kernel + (block_n * weightKOffset);
     tileMemcpySwizzleA(A_block_gmem, A_block_smem, inChannelOffset, param);
     tileMemcpySwizzleB(B_block_gmem, B_block_smem, weightKOffset, param);
@@ -1053,16 +1073,17 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
         if (block_k != num_block_tiles_k)
         {
-            half* A_block_gmem = A + (block_m * BM * A_stride) + (block_k * BK);
-            half* B_block_gmem = B + (block_k * BK * B_stride) + (block_n * BN);
-            tileMemcpyLoad(A_block_gmem, A_gmem_cache_reg, K);
-            tileMemcpyLoad(B_block_gmem, B_gmem_cache_reg, N);
+            // half* A_block_gmem = A + (block_m * BM * A_stride) + (block_k * BK);
+            half* A_block_gmem = input;
+            half* B_block_gmem = kernel + (block_n * weightKOffset);
+            tileMemcpyLoadA(A_block_gmem, A_gmem_cache_reg, block_k * BK, inChannelOffset, param);
+            tileMemcpyLoadB(B_block_gmem, B_gmem_cache_reg, block_k * BK, weightKOffset, param);
         }
 
         half* A_warp_tile = A_block_smem + (warp_m * WM * BK);
-        half* B_warp_tile = B_block_smem + (warp_n * WN);
+        half* B_warp_tile = B_block_smem + (warp_n * WN * BK);
         ldmatrix_a(A_warp_tile, A_register_);
-        ldmatrix_b(B_warp_tile, B_register_);
+        ldmatrix_b(B_warp_tile, B_register_);
 
         // outer product between mma tiles
         #pragma unroll
@@ -1097,47 +1118,47 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
             B_block_smem = B_block_smem + BUFFER_SIZE * offset_direction;
             offset_direction = -1 * offset_direction;
 
-            tileMemcpySwizzleStoreA(A_gmem_cache_reg, A_block_smem);
-            tileMemcpySwizzleStoreB (B_gmem_cache_reg, B_block_smem);
+            tileMemcpySwizzleStore(A_gmem_cache_reg, A_block_smem);
+            tileMemcpySwizzleStore(B_gmem_cache_reg, B_block_smem);
        }
    }
 
    //////////////
    // epilogue //
    //////////////
-    half alpha_ = (half)alpha;
-    half beta_ = (half)beta;
-    half C_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][4];
+//    half alpha_ = (half)alpha;
+//    half beta_ = (half)beta;
+//    half C_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][4];
 
-    // calculate pointers for this warps C and D tiles
-    half* C_block_gmem = C + (block_m * BM_dim * CD_stride) + (block_n * BN_dim);
-    half* C_warp_gmem = C_block_gmem + (warp_m * WM_dim * CD_stride) + (warp_n * WN_dim);
-    half* D_block_gmem = D + (block_m * BM_dim * CD_stride) + (block_n * BN_dim);
-    half* D_warp_gmem = D_block_gmem + (warp_m * WM_dim * CD_stride) + (warp_n * WN_dim);
+//    // calculate pointers for this warps C and D tiles
+//    half* C_block_gmem = C + (block_m * BM_dim * CD_stride) + (block_n * BN_dim);
+//    half* C_warp_gmem = C_block_gmem + (warp_m * WM_dim * CD_stride) + (warp_n * WN_dim);
+//    half* D_block_gmem = D + (block_m * BM_dim * CD_stride) + (block_n * BN_dim);
+//    half* D_warp_gmem = D_block_gmem + (warp_m * WM_dim * CD_stride) + (warp_n * WN_dim);
 
-    for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++)
-    {
-        for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++)
-        {
-            half* C_mma_tile = C_warp_gmem + (mma_m * MMA_M_dim * CD_stride) + (mma_n * MMA_N_dim);
-            ldmatrix_m16n8_gmem(C_mma_tile, C_register[mma_m][mma_n], N * sizeof(half));
+//    for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++)
+//    {
+//        for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++)
+//        {
+//            half* C_mma_tile = C_warp_gmem + (mma_m * MMA_M_dim * CD_stride) + (mma_n * MMA_N_dim);
+//            ldmatrix_m16n8_gmem(C_mma_tile, C_register[mma_m][mma_n], N * sizeof(half));
 
-            // scale C by beta
-            acc_register_[mma_m][mma_n][0] = acc_register_[mma_m][mma_n][0] * alpha_ + C_register[mma_m][mma_n][0] * beta_;
-            acc_register_[mma_m][mma_n][1] = acc_register_[mma_m][mma_n][1] * alpha_ + C_register[mma_m][mma_n][1] * beta_;
-            acc_register_[mma_m][mma_n][2] = acc_register_[mma_m][mma_n][2] * alpha_ + C_register[mma_m][mma_n][2] * beta_;
-            acc_register_[mma_m][mma_n][3] = acc_register_[mma_m][mma_n][3] * alpha_ + C_register[mma_m][mma_n][3] * beta_;
-        }
-    }
+//            // scale C by beta
+//            acc_register_[mma_m][mma_n][0] = acc_register_[mma_m][mma_n][0] * alpha_ + C_register[mma_m][mma_n][0] * beta_;
+//            acc_register_[mma_m][mma_n][1] = acc_register_[mma_m][mma_n][1] * alpha_ + C_register[mma_m][mma_n][1] * beta_;
+//            acc_register_[mma_m][mma_n][2] = acc_register_[mma_m][mma_n][2] * alpha_ + C_register[mma_m][mma_n][2] * beta_;
+//            acc_register_[mma_m][mma_n][3] = acc_register_[mma_m][mma_n][3] * alpha_ + C_register[mma_m][mma_n][3] * beta_;
+//        }
+//    }
 
-    for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++)
-    {
-        for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++)
-        {
-            half* D_mma_tile = D_warp_gmem + (mma_m * MMA_M_dim * CD_stride) + (mma_n * MMA_N_dim);
-            stmatrix_m16n8(D_mma_tile, acc_register_[mma_m][mma_n], N * sizeof(half));
-        }
-    }
+//    for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++)
+//    {
+//        for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++)
+//        {
+//            half* D_mma_tile = D_warp_gmem + (mma_m * MMA_M_dim * CD_stride) + (mma_n * MMA_N_dim);
+//            stmatrix_m16n8(D_mma_tile, acc_register_[mma_m][mma_n], N * sizeof(half));
+//        }
+//    }
 }
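The ldmatrix_b changes above assume the B tile was written to shared memory with the two-step XOR swizzle that tileMemcpySwizzleStore (added in conv2d-implicit.cuh below) applies to the float4 index: bit 4 is folded into bit 0 and bits 3-2 into bits 1-0, which on the half-element offsets used here becomes the 0b10000000 >> 4 and 0b1100000 >> 2 pattern. A minimal standalone host-side sketch of that index transform, with hypothetical names (not taken from the patch), is:

// swizzle_check.cu -- illustrative only; reproduces the two XOR steps of
// tileMemcpySwizzleStore for a 32-half (4 x float4) wide tile and prints the
// physical float4 slot that each logical (row, col4) position lands in.
#include <cstdio>

static unsigned int swizzle_float4_index(unsigned int idx) {
    idx ^= (idx & 0b10000) >> 4;   // SWIZZLE_MASK_1 / SWIZZLE_BITS_1
    idx ^= (idx & 0b1100)  >> 2;   // SWIZZLE_MASK_2 / SWIZZLE_BITS_2
    return idx;
}

int main() {
    for (unsigned int row = 0; row < 16; ++row) {        // 16 rows show the full repeating pattern
        for (unsigned int col4 = 0; col4 < 4; ++col4) {  // 4 float4 slots per 32-half row
            const unsigned int logical  = row * 4 + col4;
            const unsigned int physical = swizzle_float4_index(logical);
            printf("%2u ", physical % 4);                // row bits are untouched, so %4 is the slot within the row
        }
        printf("\n");
    }
    return 0;
}

Because only the low column bits are permuted, the same transform works at half granularity (masks shifted left by 3), which is what the rewritten ldmatrix_b address math uses; the XOR address steps in the asm (src_addr ^= 0b10000 and src_addr ^= 0b110000, i.e. 16 and 48 bytes) appear to step through the k slices consistently with this layout.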
diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh
index 9c15d72c8f..1a54a184a8 100644
--- a/ggml/src/ggml-cuda/conv2d-implicit.cuh
+++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh
@@ -190,15 +190,176 @@
 template <unsigned int TILE_ROWS, unsigned int TILE_COLS, unsigned int NUM_THREADS, unsigned int ELEMENTS_PER_THREAD>
-__device__ __forceinline__ void tileMemcpyLoad(
+__device__ __forceinline__ void tileMemcpyLoadA(
     half* src,
     float4 (&dst_reg)[ELEMENTS_PER_THREAD],
-    const unsigned int src_stride
+    // const unsigned int src_stride,
+    const unsigned int block_k,
+    const unsigned int inChannelOffset,
+    param_t param
 )
 {
     // reinterpret input/output as float4
     float4* src_float4 = reinterpret_cast<float4*>(src);
-    const unsigned int src_stride_vectorized = src_stride / 8;
+    // const unsigned int src_stride_vectorized = src_stride / 8;
+
+    // # of threads is multiple of # of columns in the tile
+    constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
+    static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);
+
+    // flatten out 2d grid of threads into in order of increasing threadIdx.x
+    const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;
+
+    // assign each thread a row/column in the tile, calculate how many iterations we need
+    // to cover the whole tile
+    constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
+    constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
+    unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
+    const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
+
+    // compile time check that we provided the right amount of registers for storage
+    static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
+
+    #pragma unroll
+    for (unsigned int i = 0; i < NUM_ITERS; i++)
+    {
+        // const unsigned int src_index = thread_row * src_stride_vectorized + thread_col;
+        // dst_reg[i] = src_float4[src_index];
+        // thread_row += ROW_STEP;
+        unsigned int gemm_i = blockDim.y * TILE_ROWS + thread_row;
+        unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv);
+        unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
+        int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p;
+        int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q;
+        unsigned int inOffset = n * param.c * param.h * param.w;
+        // TODO: next block_k loop
+        const uint curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv);                                  // kernel row offset
+        const uint curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv);     // kernel col offset
+        const uint curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv);  // channel offset
+        int curH = posh_ori + curR * param.d_h; // input h
+        int curW = posw_ori + curS * param.d_w; // input w
+        if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
+            curR < param.R && curS < param.S && curC < param.c){
+            // const unsigned int src_index = thread_row * src_stride_vectorized + thread_col;
+            const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
+            dst_reg[i] = reinterpret_cast<float4*>(&src[inOffset + inOffsetTmp])[0];
+        } else{
+            dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);
+        }
+        thread_row += ROW_STEP;
+    }
+}
+
+
+template <unsigned int TILE_ROWS, unsigned int TILE_COLS, unsigned int NUM_THREADS, unsigned int ELEMENTS_PER_THREAD>
+__device__ __forceinline__ void tileMemcpyLoadB(
+    half* src,
+    float4 (&dst_reg)[ELEMENTS_PER_THREAD],
+    const unsigned int block_k,
+    const unsigned int src_stride,
+    param_t param
+)
+{
+    // reinterpret input/output as float4
+    float4* src_float4 = reinterpret_cast<float4*>(src);
+    // const unsigned int src_stride_vectorized = src_stride / 8;
+
+    // # of threads is multiple of # of columns in the tile
+    constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
+    static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);
+
+    // flatten out 2d grid of threads into in order of increasing threadIdx.x
+    const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;
+
+    // assign each thread a row/column in the tile, calculate how many iterations we need
+    // to cover the whole tile
+    constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
+    constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
+    unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
+    const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
+
+    // compile time check that we provided the right amount of registers for storage
+    static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
+
+    const uint curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv);                                  // kernel row offset
+    const uint curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv);     // kernel col offset
+    const uint curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv);  // channel offset
+
+    #pragma unroll
+    for (unsigned int i = 0; i < NUM_ITERS; i++)
+    {
+        // const unsigned int src_index = thread_row * src_stride_vectorized + thread_col;
+        // dst_reg[i] = src_float4[src_index];
+        // thread_row += ROW_STEP;
+        const unsigned int src_index = thread_row * src_stride + block_k + thread_col * 8;
+        if (thread_row < param.k && curR < param.R && curS < param.S && curC < param.c){
+            dst_reg[i] = reinterpret_cast<float4*>(&src[src_index])[0];
+        }else{ // out of bounds: zero-fill
+            dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);
+        }
+        thread_row += ROW_STEP;
+    }
+}
+
+// template <unsigned int TILE_ROWS, unsigned int TILE_COLS, unsigned int NUM_THREADS, unsigned int ELEMENTS_PER_THREAD, unsigned int SWIZZLE_BITS>
+// __device__ __forceinline__ void tileMemcpySwizzleStoreB(
+//     float4 src_reg[ELEMENTS_PER_THREAD],
+//     half* dst
+// )
+// {
+//     constexpr unsigned int SWIZZLE_MASK = 0b111 << SWIZZLE_BITS;
+
+//     // reinterpret input/output as float4
+//     float4* dst_float4 = reinterpret_cast<float4*>(dst);
+
+//     // # of threads is multiple of # of columns in the tile
+//     constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
+//     static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);
+
+//     // flatten out 2d grid of threads into in order of increasing threadIdx.x
+//     const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;
+
+//     // assign each thread a row/column in the tile, calculate how many iterations we need
+//     // to cover the whole tile
+//     constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
+//     constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
+//     unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
+//     const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
+
+//     // compile time check that we provided the right amount of registers for storage
+//     static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
+
+//     #pragma unroll
+//     for (unsigned int i = 0; i < NUM_ITERS; i++)
+//     {
+//         // apply swizzle to the dst index
+//         unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
+//         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK) >> SWIZZLE_BITS);
+//         dst_float4[dst_index] = src_reg[i];
+//         thread_row += ROW_STEP;
+//     }
+// }

+// same as above but without the swizzle
+template <unsigned int TILE_ROWS, unsigned int TILE_COLS, unsigned int NUM_THREADS, unsigned int ELEMENTS_PER_THREAD>
+__device__ __forceinline__ void tileMemcpyStore(
+    float4 src_reg[ELEMENTS_PER_THREAD],
+    half* dst,
+    unsigned int dst_stride_float4
+)
+{
+    // reinterpret input/output as float4
+    float4* dst_float4 = reinterpret_cast<float4*>(dst);
 
     // # of threads is multiple of # of columns in the tile
     constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
     static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);
 
@@ -220,11 +381,72 @@ __device__ __forceinline__ void tileMemcpyLoad(
     #pragma unroll
     for (unsigned int i = 0; i < NUM_ITERS; i++)
     {
-        const unsigned int src_index = thread_row * src_stride_vectorized + thread_col;
-        dst_reg[i] = src_float4[src_index];
+        // compute the (non-swizzled) dst index
+        unsigned int dst_index = thread_row * dst_stride_float4 + thread_col;
+        dst_float4[dst_index] = src_reg[i];
         thread_row += ROW_STEP;
     }
 }
+
+// this is a special case of the above for when TILE_COLS == 32
+template <unsigned int TILE_ROWS, unsigned int NUM_THREADS, unsigned int ELEMENTS_PER_THREAD>
+__device__ __forceinline__ void tileMemcpySwizzleStore(
+    const float4 (&src_reg)[ELEMENTS_PER_THREAD],
+    half* dst
+)
+{
+    constexpr unsigned int SWIZZLE_MASK_1 = 0b10000;
+    constexpr unsigned int SWIZZLE_BITS_1 = 4;
+    constexpr unsigned int SWIZZLE_MASK_2 = 0b1100;
+    constexpr unsigned int SWIZZLE_BITS_2 = 2;
+    constexpr unsigned int TILE_COLS = 32;
+
+    // reinterpret input/output as float4
+    float4* dst_float4 = reinterpret_cast<float4*>(dst);
+
+    // # of threads is multiple of # of columns in the tile
+    constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
+    static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);
+
+    // flatten out 2d grid of threads into in order of increasing threadIdx.x
+    const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;
+
+    // assign each thread a row/column in the tile, calculate how many iterations we need
+    // to cover the whole tile
+    constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
+    constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
+    unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
+    const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
+
+    // compile time check that we provided the right amount of registers for storage
+    static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
+
+    #pragma unroll
+    for (unsigned int i = 0; i < NUM_ITERS; i++)
+    {
+        // apply swizzle to the dst index
+        unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
+        dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
+        dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
+        dst_float4[dst_index] = src_reg[i];
+        thread_row += ROW_STEP;
+    }
+}
+
+__device__ __forceinline__ uint32_t cvta_to_shared_u32(const void *pointer) {
+    uint32_t address;
+    asm("{\n\t"
+        "  .reg .u64 u64addr;\n\t"
+        "  cvta.to.shared.u64 u64addr, %1;\n\t"
+        "  cvt.u32.u64 %0, u64addr;\n\t"
+        "}"
+        : "=r"(address)
+        : "l"(pointer));
+    return address;
+}
+
 #endif
 
 #define CUDA_CONV2D_IMPLICT_BLOCK_SIZE 256
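The fastdiv/fastmodulo arithmetic in tileMemcpyLoadA is the usual implicit-GEMM mapping from a GEMM coordinate to convolution coordinates. The sketch below spells the same mapping out with plain integer division, assuming NHWC activations, a K x R x S x C weight layout and the stride/pad/dilation naming (u, v, p, q, d_h, d_w) that the patch's indexing suggests; the struct and function names are illustrative only, not ggml API.

// Plain-integer model of the index math in tileMemcpyLoadA (illustrative names).
struct ConvParams {
    int N, C, H, W;    // input  (NHWC)
    int K, R, S;       // weights (K x R x S x C assumed)
    int OH, OW;        // output spatial size
    int u, v;          // stride  (h, w)  -- param.u / param.v in the patch
    int p, q;          // padding (h, w)  -- param.p / param.q
    int d_h, d_w;      // dilation
};

struct InputIndex { int n, h, w, c; bool valid; };

// gemm_i indexes the output pixel (row of the implicit GEMM),
// gemm_k indexes the reduction dimension R*S*C (column of the A tile).
static InputIndex gemm_to_input(const ConvParams& P, int gemm_i, int gemm_k) {
    const int n   = gemm_i / (P.OH * P.OW);        // fastdiv(gemm_i, OHOW_fastdiv)
    const int res = gemm_i % (P.OH * P.OW);        // fastmodulo(gemm_i, OHOW_fastdiv)
    const int oh  = res / P.OW;
    const int ow  = res % P.OW;

    const int r = gemm_k / (P.S * P.C);            // kernel row    (curR)
    const int s = (gemm_k % (P.S * P.C)) / P.C;    // kernel column (curS)
    const int c = gemm_k % P.C;                    // channel       (curC)

    const int h = oh * P.u - P.p + r * P.d_h;      // posh_ori + curR * d_h
    const int w = ow * P.v - P.q + s * P.d_w;      // posw_ori + curS * d_w

    const bool valid = (h >= 0 && h < P.H && w >= 0 && w < P.W);
    return { n, h, w, c, valid };
}

Out-of-range (h, w) taps correspond to the implicit zero padding, which is why tileMemcpyLoadA writes make_float4(0.f, 0.f, 0.f, 0.f) for them instead of reading from global memory.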
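For checking the kernel end to end, a naive host-side reference in the same layouts is convenient. This sketch reuses the illustrative ConvParams struct from above, stores half data as float for simplicity, and accumulates in fp32; it is a validation aid under the same layout assumptions, not part of the patch.

// Naive NHWC conv2d reference (illustrative) for validating the implicit-GEMM kernel.
// input : N x H x W x C, weight: K x R x S x C, output: N x OH x OW x K
#include <vector>

static void conv2d_reference(const std::vector<float>& input, const std::vector<float>& weight,
                             std::vector<float>& output, const ConvParams& P) {
    for (int n = 0; n < P.N; ++n)
    for (int oh = 0; oh < P.OH; ++oh)
    for (int ow = 0; ow < P.OW; ++ow)
    for (int k = 0; k < P.K; ++k) {
        float acc = 0.0f;
        for (int r = 0; r < P.R; ++r)
        for (int s = 0; s < P.S; ++s)
        for (int c = 0; c < P.C; ++c) {
            const int h = oh * P.u - P.p + r * P.d_h;
            const int w = ow * P.v - P.q + s * P.d_w;
            if (h < 0 || h >= P.H || w < 0 || w >= P.W) continue;   // implicit zero padding
            const float x = input [((n * P.H + h) * P.W + w) * P.C + c];
            const float g = weight[((k * P.R + r) * P.S + s) * P.C + c];
            acc += x * g;
        }
        output[((n * P.OH + oh) * P.OW + ow) * P.K + k] = acc;
    }
}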