make CI happy

This commit is contained in:
bssrdf 2025-10-29 23:23:21 -04:00
parent 2b5351a898
commit c1f67c19e0
2 changed files with 16 additions and 15 deletions

View File

@ -85,8 +85,8 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
// Warp tile
const uint lane_id = tx % WARPSIZE;
const uint warp_id = tx / WARPSIZE;
const int mma_tid_x = warp_id / (BN / WN);
const int mma_tid_y = warp_id % (BN / WN);
const int mma_tid_x = warp_id / (BN / WN);
const int mma_tid_y = warp_id % (BN / WN);
// size of the warp subtile
constexpr uint WMITER = (WM * WN) / (WARPSIZE * TM * TN * WNITER);
@ -449,7 +449,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
const int n = (ksplit > 0) ? gemm_i / PQ : z;
const int col = (ksplit > 0) ? gemm_i % PQ : gemm_i;
if (n < param.n && row < param.k && col < param.Oh * param.Ow){
const uint outOffset = ksplit > 0 ?
const uint outOffset = ksplit > 0 ?
z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow +
row * param.Oh * param.Ow + col :
z * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col;
@ -626,7 +626,7 @@ __device__ __forceinline__ void ldmatrix_b(
static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");
static_assert(mma_tiles_per_warp_n == 8, "mma_tiles_per_warp_n must be 8");
uint32_t (&reg_) [4][8] = reinterpret_cast<uint32_t(&)[4][8]>(reg);
unsigned int logical_offset = (threadIdx.x % 32) * smem_stride;
unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b10000000) >> 4);
@ -739,11 +739,11 @@ constexpr unsigned int MMA_N = 8;
constexpr int BUFFER_SIZE = BM * BK + BK * BN;
// declare register storage
// ptx instructions expect uint32_t registers, where each uint32_t is 2 halfs packed together
// ptx instructions expect uint32_t registers, where each uint32_t is 2 halfs packed together
uint32_t acc_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][2];
uint32_t A_register[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2];
uint32_t B_register[mma_tiles_per_warp_k][mma_tiles_per_warp_n];
// convenience cast to half for register storage
half (&acc_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_n][4] = reinterpret_cast<half(&)[mma_tiles_per_warp_m][mma_tiles_per_warp_n][4]>(acc_register);
half (&A_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] = reinterpret_cast<half(&)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4]>(A_register);
@ -827,7 +827,7 @@ constexpr unsigned int MMA_N = 8;
// reuse smem
half *smemoutput = shmem;
const uint lane_id = threadIdx.x % WARPSIZE;
const uint lane_id = threadIdx.x % WARPSIZE;
const uint mma_row = lane_id / 4;
const uint mma_col = lane_id % 4;
const uint output_lds_addr = warp_m * WM * BN/2 + lane_id * BN/2 + warp_n * WN/2;
@ -845,7 +845,7 @@ constexpr unsigned int MMA_N = 8;
for (unsigned int mma_n = i * mma_tiles_per_warp_n/2; mma_n < (i+1)*mma_tiles_per_warp_n/2; mma_n++)
{
uint32_t (&reg_)[2] = reinterpret_cast<uint32_t(&)[2]>(acc_register_[mma_m][mma_n]);
uint idx = output_sts_addr +
uint idx = output_sts_addr +
mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N;
idx = idx ^ ((idx & 0b1110000000) >> 4);
uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[idx]);
@ -902,7 +902,7 @@ constexpr static int conv_shapes[][NUM_VARIANTS] = {
};
template <typename T, unsigned int CONV_SHAPE>
static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const param_t P, cudaStream_t st) {
static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const param_t P, cudaStream_t st) {
const uint BM = conv_shapes[0][CONV_SHAPE];
const uint BN = conv_shapes[1][CONV_SHAPE];
@ -920,7 +920,7 @@ static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D,
int threadz = 1; // threadz number per block
dim3 thblock(NUM_THREADS, thready, threadz);
dim3 grid(blockx, blocky, blockz);
conv2d_implicit_kernel<T, BM, BN, BK, WM, WN,
WNITER, TM, TN, NUM_THREADS, 1, false, 0><<<grid, thblock, 0, st>>>(X_D, K_D, Y_D, P);
}
@ -991,6 +991,7 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
static void conv2d_implicit_cuda_f32(ggml_backend_cuda_context & ctx, const float * X_D, const float * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) {
conv2d_implicit_cuda<float, 1>(X_D, K_D, Y_D, P, st);
GGML_UNUSED(ctx);
GGML_UNUSED(cc);
}
void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

View File

@ -137,7 +137,7 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
curR < param.r && curS < param.s && curC < param.c){
const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[inOffset + inOffsetTmp])[0];
@ -199,7 +199,7 @@ __device__ __forceinline__ void tileMemcpyLoadA(
const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
int curH = posh_ori + curR * param.d_h; // input h
int curW = posw_ori + curS * param.d_w; // input w
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
curR < param.r && curS < param.s && curC < param.c){
const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
dst_reg[i] = reinterpret_cast<const float4 *>(&src[inOffset + inOffsetTmp])[0];
@ -215,7 +215,7 @@ __device__ __forceinline__ void tileMemcpyLoadA(
GGML_UNUSED(inChannelOffset);
GGML_UNUSED(param);
NO_DEVICE_CODE;
#endif
#endif
}
@ -299,7 +299,7 @@ __device__ __forceinline__ void tileMemcpySwizzleStore(
// # of threads is multiple of # of columns in the tile
constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);
// flatten out 2d grid of threads into in order of increasing threadIdx.x
const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;
@ -312,7 +312,7 @@ __device__ __forceinline__ void tileMemcpySwizzleStore(
// compile time check that we provided the right amount of registers for storage
static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++)
{