WIP: tensor core code compiled ok

bssrdf 2025-10-24 11:41:11 -04:00
parent 2715341c1d
commit 80a996cfc0
2 changed files with 126 additions and 91 deletions

View File

@ -730,7 +730,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
}
}
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
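// ldmatrix_a: one warp pulls its 8 (M) x 4 (K) grid of MMA A fragments out
// of the swizzled shared-memory tile via ldmatrix.sync.aligned PTX; each
// x4 issue below loads four 8x8 b16 sub-tiles at once.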
template <unsigned int mma_tiles_per_warp_m, unsigned int mma_tiles_per_warp_k, unsigned int smem_stride>
__device__ __forceinline__ void ldmatrix_a(
@ -738,6 +738,7 @@ __device__ __forceinline__ void ldmatrix_a(
half (&reg)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4]
)
{
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
static_assert(mma_tiles_per_warp_m == 8, "mma_tiles_per_warp_m must be 8");
static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");
@ -880,7 +881,11 @@ __device__ __forceinline__ void ldmatrix_a(
: "=r"(reg_[6][3][0]), "=r"(reg_[6][3][1]), "=r"(reg_[7][3][0]), "=r"(reg_[7][3][1])
: "r"(src_addr + 96 * smem_stride_)
);
#else
GGML_UNUSED(src);
GGML_UNUSED(reg);
NO_DEVICE_CODE;
#endif
}
template <unsigned int mma_tiles_per_warp_k, unsigned int mma_tiles_per_warp_n, unsigned int smem_stride>
@ -889,10 +894,11 @@ __device__ __forceinline__ void ldmatrix_b(
half (&reg)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2]
)
{
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");
static_assert(mma_tiles_per_warp_n == 8, "mma_tiles_per_warp_n must be 8");
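// same scheme for B: each x4 ldmatrix issue fills four 8x8 b16 sub-tiles of
// the 4 (K) x 8 (N) fragment grid, walking the tile via smem_stride_ offsets.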
// uint32_t (&reg_) [4][8] = reinterpret_cast<uint32_t(&)[4][8]>(reg);
uint32_t (&reg_) [4][8] = reinterpret_cast<uint32_t(&)[4][8]>(reg);
// const unsigned int logical_offset = ((threadIdx.x % 8) * smem_stride) + (((threadIdx.x % 32) / 8) * 8);
// unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b11100000000) >> 5);
// uint32_t src_addr = cvta_to_shared_u32(src + swizzled_offset);
@ -979,15 +985,20 @@ __device__ __forceinline__ void ldmatrix_b(
// : "r"(src_addr ^ 0b1000000)
: "r"(src_addr + 32 * smem_stride_)
);
#else
GGML_UNUSED(src);
GGML_UNUSED(reg);
NO_DEVICE_CODE;
#endif
}
template<const int BM, const int BN, const int BK, const int WM, const int WN,
const int WK, const int NUM_THREADS>
static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
static __global__ void conv2d_implicit_kernel_tc(const half * __restrict__ input,
const half * __restrict__ kernel,
half * __restrict__ output,
const param_t param) {
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
constexpr unsigned int MMA_M = 16;
constexpr unsigned int MMA_N = 8;
@ -997,18 +1008,18 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
const uint weightKOffset = param.c * param.r * param.s;
// for convenience/readability in index calculations
const unsigned int A_stride = K;
const unsigned int B_stride = N;
const unsigned int CD_stride = N;
// const unsigned int A_stride = K;
// const unsigned int B_stride = N;
// const unsigned int CD_stride = N;
// calculate how many bits of shared memory indices are going to be swizzled, and create masks
constexpr unsigned int SWIZZLE_BITS_B = int_log2(BN / 8);
// constexpr unsigned int SWIZZLE_BITS_B = int_log2(BN / 8);
// loop bounds, constexpr where possible allows for loop unrolling
constexpr unsigned int mma_tiles_per_warp_k = 4;
constexpr unsigned int mma_tiles_per_warp_m = WM / MMA_M;
constexpr unsigned int mma_tiles_per_warp_n = WN / MMA_N;
const unsigned int num_block_tiles_k = K / BK;
const unsigned int num_block_tiles_k = (K + (BK-1)) / BK;
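// ceil-divide: the reduction dim K (c*r*s in the implicit GEMM) need not be
// a multiple of BK; the tile-copy helpers zero-fill out-of-range elements,
// so the ragged last tile is harmless.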
// calculate block/warp indices
const unsigned int block_m = blockIdx.y;
@ -1057,8 +1068,8 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
// prefetch the first block tile of A,B into shared memory
// half* A_block_gmem = input + (block_m * BM * A_stride);
half* A_block_gmem = input;
half* B_block_gmem = kernel + (block_n * weightKOffset);
const half* A_block_gmem = input;
const half* B_block_gmem = kernel + (block_n * weightKOffset);
tileMemcpySwizzleA<BM, NUM_THREADS>(A_block_gmem, A_block_smem, inChannelOffset, param);
tileMemcpySwizzleB<BN, NUM_THREADS>(B_block_gmem, B_block_smem, weightKOffset, param);
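// mainloop is double buffered (hence the *2 in the dynamic smem size):
// while the tensor cores consume the resident tile, the next K tile is
// staged through registers (A/B_gmem_cache_reg) into the other smem buffer.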
@ -1074,10 +1085,10 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
if (block_k != num_block_tiles_k)
{
// half* A_block_gmem = A + (block_m * BM * A_stride) + (block_k * BK);
half* A_block_gmem = input;
half* B_block_gmem = kernel + (block_n * weightKOffset);
tileMemcpyLoad<BM, BK, NUM_THREADS, 4>(A_block_gmem, A_gmem_cache_reg, block_k * BK, inChannelOffset, param);
tileMemcpyLoad<BN, BK, NUM_THREADS, 4>(B_block_gmem, B_gmem_cache_reg, block_k * BK, weightKOffset, param);
const half* A_block_gmem = input;
const half* B_block_gmem = kernel + (block_n * weightKOffset);
tileMemcpyLoadA<BM, BK, NUM_THREADS, 4>(A_block_gmem, A_gmem_cache_reg, block_k * BK, inChannelOffset, param);
tileMemcpyLoadB<BN, BK, NUM_THREADS, 4>(B_block_gmem, B_gmem_cache_reg, block_k * BK, weightKOffset, param);
}
half* A_warp_tile = A_block_smem + (warp_m * WM * BK);
half* B_warp_tile = B_block_smem + (warp_n * WN * BK);
@ -1124,14 +1135,14 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
}
// reuse smem
half *smemoutput = reinterpret_cast<half *>(shmem);
half *smemoutput = shmem;
const uint lane_id = threadIdx.x % WARPSIZE;
const uint mma_row = lane_id / 4;
const uint mma_col = lane_id % 4;
const uint output_lds_addr = warp_id * WSUBM * WSUBN + lane_id;
const uint output_lds_addr = warp_m * WM * BN/2 + lane_id * BN/2 + warp_n * WN/2;
const uint output_sts_addr = warp_m * WM * BN/2 + mma_row * BN/2 + warp_n * WN/2 + mma_col * 2;
const uint m_idx = by * BN + mma_tid_y * WN;
const uint n_idx = block_m * BM + warp_m * WM;
const uint m_idx = block_n * BN + warp_n * WN;
const uint n_idx = block_m * BM + warp_m * WM + lane_id;
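// m_idx walks the output-channel (k) axis; n_idx walks the flattened
// n*Oh*Ow axis, which is split back into (batch, spatial) below with
// fastdiv/fastmodulo.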
#pragma unroll
for (int i = 0; i < 2; ++i)
@ -1142,72 +1153,42 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
{
// output sts
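// each thread owns 4 halves of an m16n8 accumulator: halves {0,1} sit in
// the fragment's upper 8 rows and {2,3} eight rows lower, so one uint32_t
// pair is stored here and its sibling at +8*BN/2.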
uint32_t (&reg_)[2] = reinterpret_cast<uint32_t(&)[2]>(acc_register_[mma_m][mma_n]);
uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[output_sts_addr +
mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N]);
const uint idx = output_sts_addr +
mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N;
uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[idx]);
dst_ptr[0] = reg_[0];
dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[output_sts_addr +
mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N + 8 * BN / 2]);
dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[idx + 8 * BN / 2]);
dst_ptr[0] = reg_[1];
}
}
__syncthreads();
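// barrier: all warps' accumulator stores must land in smem before threads
// read back (possibly other warps') values for the coalesced global store.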
#pragma unroll
const uint row = m_idx + j * WSUBN + (lane_id + subk * WARPSIZE) / WSUBM;
const uint gemm_i = n_idx + i * WSUBM + (lane_id + subk * WARPSIZE) % WSUBM;
const int n = fastdiv(gemm_i, param.OHOW_fastdiv);
const int col = fastmodulo(gemm_i, param.OHOW_fastdiv);
if (n < param.n && row < param.k && col < param.Oh * param.Ow){
// int outOffset = z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow + (m_idx + i * 16 + subk) * param.Oh * param.Ow + (n_idx + j * 32);
// if (n < param.n && (m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow)
// param.interm[outOffset] = smemoutput[output_lds_addr + subk * 32];
const uint outOffset = ksplit > 0 ?
z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow +
row * param.Oh * param.Ow + col :
z * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col;
output[outOffset] = smemoutput[output_lds_addr + subk * WARPSIZE];
for (int subk = 0; subk < WN / 2; ++subk){
for (int j = 0; j < 4; ++j){
const uint row = m_idx + subk + i * WN / 2;
const uint gemm_i = n_idx + j*32;
const int n = fastdiv(gemm_i, param.OHOW_fastdiv);
const int col = fastmodulo(gemm_i, param.OHOW_fastdiv);
if(n < param.n && row < param.k && col < param.Oh * param.Ow){
// int outOffset = z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow + (m_idx + i * 16 + subk) * param.Oh * param.Ow + (n_idx + j * 32);
// if (n < param.n && (m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow)
// param.interm[outOffset] = smemoutput[output_lds_addr + subk * 32];
const uint outOffset = n * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col;
output[outOffset] = smemoutput[output_lds_addr + subk + j*32*BN/2];
}
}
}
}
//////////////
// epilogue //
//////////////
// half alpha_ = (half)alpha;
// half beta_ = (half)beta;
// half C_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][4];
// // calculate pointers for this warps C and D tiles
// half* C_block_gmem = C + (block_m * BM_dim * CD_stride) + (block_n * BN_dim);
// half* C_warp_gmem = C_block_gmem + (warp_m * WM_dim * CD_stride) + (warp_n * WN_dim);
// half* D_block_gmem = D + (block_m * BM_dim * CD_stride) + (block_n * BN_dim);
// half* D_warp_gmem = D_block_gmem + (warp_m * WM_dim * CD_stride) + (warp_n * WN_dim);
// for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++)
// {
// for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++)
// {
// half* C_mma_tile = C_warp_gmem + (mma_m * MMA_M_dim * CD_stride) + (mma_n * MMA_N_dim);
// ldmatrix_m16n8_gmem(C_mma_tile, C_register[mma_m][mma_n], N * sizeof(half));
// // scale C by beta
// acc_register_[mma_m][mma_n][0] = acc_register_[mma_m][mma_n][0] * alpha_ + C_register[mma_m][mma_n][0] * beta_;
// acc_register_[mma_m][mma_n][1] = acc_register_[mma_m][mma_n][1] * alpha_ + C_register[mma_m][mma_n][1] * beta_;
// acc_register_[mma_m][mma_n][2] = acc_register_[mma_m][mma_n][2] * alpha_ + C_register[mma_m][mma_n][2] * beta_;
// acc_register_[mma_m][mma_n][3] = acc_register_[mma_m][mma_n][3] * alpha_ + C_register[mma_m][mma_n][3] * beta_;
// }
// }
// for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++)
// {
// for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++)
// {
// half* D_mma_tile = D_warp_gmem + (mma_m * MMA_M_dim * CD_stride) + (mma_n * MMA_N_dim);
// stmatrix_m16n8(D_mma_tile, acc_register_[mma_m][mma_n], N * sizeof(half));
// }
// }
#else
GGML_UNUSED(input);
GGML_UNUSED(kernel);
GGML_UNUSED(output);
GGML_UNUSED(param);
NO_DEVICE_CODE;
#endif
}
#endif
#define NUM_VARIANTS 6
@ -1266,11 +1247,53 @@ static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D,
}
}
static void conv2d_implicit_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const param_t P, cudaStream_t st) {
static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const float * X_D, const half * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) {
if (GGML_CUDA_CC_IS_NVIDIA(cc) && ampere_mma_available(cc) && P.layout == 0 && P.c % 8 == 0) {
constexpr unsigned int BM_dim = 256;
constexpr unsigned int BN_dim = 256;
constexpr unsigned int BK_dim = 32;
constexpr unsigned int WARPS_PER_BLOCK_M = 2;
constexpr unsigned int WARPS_PER_BLOCK_N = 4;
constexpr unsigned int WARPS_PER_BLOCK_K = 4;
constexpr unsigned int WM_dim = BM_dim / WARPS_PER_BLOCK_M;
constexpr unsigned int WN_dim = BN_dim / WARPS_PER_BLOCK_N;
constexpr unsigned int WK_dim = BK_dim / WARPS_PER_BLOCK_K;
const unsigned int BlocksM = (P.n * P.Oh * P.Ow + BM_dim - 1) / BM_dim;
const unsigned int BlocksN = (P.k + BN_dim - 1) / BN_dim;
constexpr unsigned int ThreadsM = WARPS_PER_BLOCK_M;
constexpr unsigned int ThreadsN = WARPSIZE * WARPS_PER_BLOCK_N;
constexpr unsigned int NumThreads = ThreadsM * ThreadsN;
const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half);
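// 2 * (256*32 + 32*256) halves = 64 KiB of dynamic smem; launches above
// 48 KiB need an explicit one-time opt-in, e.g. cudaFuncSetAttribute(kernel,
// cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_bytes); assumed to be
// handled elsewhere, otherwise the launch fails.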
dim3 gridDim(BlocksN, BlocksM);
dim3 blockDim(ThreadsN, ThreadsM);
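// GEMM view of the conv: M = n*Oh*Ow (gridDim.y), N = k (gridDim.x); each
// block computes a BM x BN output tile with 8 warps laid out 2 (M) x 4 (N).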
int id = ggml_cuda_get_device();
ggml_cuda_pool_alloc<half> x_f16(ctx.pool(id));
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(GGML_TYPE_F32);
GGML_ASSERT(to_fp16_cuda != nullptr);
size_t ne = P.c * P.h * P.w * P.n;
x_f16.alloc(ne);
to_fp16_cuda(X_D, x_f16.get(), ne, st);
const half *X_H = x_f16.get();
ggml_cuda_pool_alloc<half> Y_H(ctx.pool(id), P.k * P.Oh * P.Ow * P.n);
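// f16 round trip: X is converted f32->f16 up front, the kernel accumulates
// in half, and Y_H is converted back to f32 after the launch.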
conv2d_implicit_kernel_tc<BM_dim, BN_dim, BK_dim,
WM_dim, WN_dim, WK_dim, NumThreads>
<<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_D, Y_H.get(), P);
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow * P.n, st);
} else {
conv2d_implicit_cuda<half, 1>(X_D, K_D, Y_D, P, st);
}
}
static void conv2d_implicit_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const param_t P, cudaStream_t st) {
static void conv2d_implicit_cuda_f32(ggml_backend_cuda_context & ctx, const float * X_D, const float * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) {
GGML_UNUSED(ctx);
GGML_UNUSED(cc);
conv2d_implicit_cuda<float, 1>(X_D, K_D, Y_D, P, st);
}
@ -1286,6 +1309,7 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor *
cudaStream_t st = ctx.stream();
const int cc = ggml_cuda_info().devices[ctx.device].cc;
const int32_t * p = (const int32_t *) dst->op_params;
const int ST_X = p[0]; // stride_x
@ -1333,8 +1357,8 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor *
params.layout = LT;
if (kernel->type == GGML_TYPE_F16) {
conv2d_implicit_cuda_f16(X_D, (half *) K_D, Y_D, params, st);
conv2d_implicit_cuda_f16(ctx, X_D, (half *) K_D, Y_D, cc, params, st);
} else {
conv2d_implicit_cuda_f32(X_D, K_D, Y_D, params, st);
conv2d_implicit_cuda_f32(ctx, X_D, K_D, Y_D, cc, params, st);
}
}

View File

@ -31,9 +31,10 @@ typedef struct{
template<unsigned int TILE_ROWS,
unsigned int NUM_THREADS>
__device__ __forceinline__ void tileMemcpySwizzleB(
half* src,
const half* src,
half* dst,
const unsigned int src_stride
const unsigned int src_stride,
param_t param
)
{
// constexpr unsigned int SWIZZLE_MASK = 0b111 << SWIZZLE_BITS;
@ -109,7 +110,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
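// two-level XOR swizzle: folding selected high index bits into the low
// bits permutes float4 slots within a row so the later ldmatrix reads
// avoid shared-memory bank conflicts.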
if (thread_row < param.k && curR < param.R && curS < param.S && curC < param.c){
if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c){
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[src_index])[0];
}else{ // read 4 halves
dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
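// out-of-range weight reads become zeros, so the padded K tail contributes
// nothing to the accumulators.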
@ -123,7 +124,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
template<unsigned int TILE_ROWS,
unsigned int NUM_THREADS>
__device__ __forceinline__ void tileMemcpySwizzleA(
half* src,
const half* src,
half* dst,
// const unsigned int src_stride,
const unsigned int inChannelOffset,
@ -175,7 +176,7 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
curR < param.R && curS < param.S && curC < param.c){
curR < param.r && curS < param.s && curC < param.c){
// const unsigned int src_index = thread_row * src_stride_vectorized + thread_col;
const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[inOffset + inOffsetTmp])[0];
@ -191,7 +192,7 @@ unsigned int TILE_COLS,
unsigned int NUM_THREADS,
unsigned int ELEMENTS_PER_THREAD>
__device__ __forceinline__ void tileMemcpyLoadA(
half* src,
const half* src,
float4 (&dst_reg)[ELEMENTS_PER_THREAD],
// const unsigned int src_stride,
const unsigned int block_k,
@ -200,7 +201,7 @@ __device__ __forceinline__ void tileMemcpyLoadA(
)
{
// reinterpret input/output as float4
float4* src_float4 = reinterpret_cast<float4*>(src);
// const float4* src_float4 = reinterpret_cast<const float4*>(src);
// const unsigned int src_stride_vectorized = src_stride / 8;
// # of threads is multiple of # of columns in the tile
@ -239,7 +240,7 @@ __device__ __forceinline__ void tileMemcpyLoadA(
int curH = posh_ori + curR * param.d_h; // input h
int curW = posw_ori + curS * param.d_w; // input w
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
curR < param.R && curS < param.S && curC < param.c){
curR < param.r && curS < param.s && curC < param.c){
// const unsigned int src_index = thread_row * src_stride_vectorized + thread_col;
const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
dst_reg[i] = reinterpret_cast<const float4 *>(&src[inOffset + inOffsetTmp])[0];
@ -256,7 +257,7 @@ unsigned int TILE_COLS,
unsigned int NUM_THREADS,
unsigned int ELEMENTS_PER_THREAD>
__device__ __forceinline__ void tileMemcpyLoadB(
half* src,
const half* src,
float4 (&dst_reg)[ELEMENTS_PER_THREAD],
const unsigned int block_k,
const unsigned int src_stride,
@ -264,7 +265,7 @@ __device__ __forceinline__ void tileMemcpyLoadB(
)
{
// reinterpret input/output as float4
float4* src_float4 = reinterpret_cast<float4*>(src);
// const float4* src_float4 = reinterpret_cast<const float4*>(src);
// const unsigned int src_stride_vectorized = src_stride / 8;
// # of threads is multiple of # of columns in the tile
@ -295,7 +296,7 @@ __device__ __forceinline__ void tileMemcpyLoadB(
// dst_reg[i] = src_float4[src_index];
// thread_row += ROW_STEP;
const unsigned int src_index = thread_row * src_stride + block_k + thread_col * 8;
if (thread_row < param.k && curR < param.R && curS < param.S && curC < param.c){
if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c){
dst_reg[i] = reinterpret_cast<const float4 *>(&src[src_index])[0];
}else{ // read 4 halves
dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);
@ -449,5 +450,15 @@ __device__ __forceinline__ uint32_t cvta_to_shared_u32(const void *pointer) {
#endif
// constexpr unsigned int int_log2(unsigned int x)
// {
// unsigned int result = 0;
// while (x >>= 1)
// {
// result++;
// }
// return result;
// }
#define CUDA_CONV2D_IMPLICT_BLOCK_SIZE 256
void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst);