various small optimizations

bssrdf 2025-11-14 13:51:07 -05:00
parent ecbbdb6608
commit e4fbece606
2 changed files with 37 additions and 12 deletions

@@ -811,7 +811,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
     constexpr unsigned int TILE_COLS_VECTORIZED = BK / 8;
     constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
     constexpr unsigned int A_K_STRID = BM / ROW_STEP;
-    constexpr unsigned int B_K_STRID = BN / ROW_STEP;
+    // constexpr unsigned int B_K_STRID = BN / ROW_STEP;
     unsigned int masks_a[A_K_STRID][2];
     int64_t element_offset_a[A_K_STRID];
@@ -1263,9 +1263,9 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
     // if (BlocksM * BlocksN < nsm && P.c >= 8 * ksplit && (P.c * P.r * P.s) % (8*ksplit) == 0) {
     if (BlocksM * BlocksN < 2*(unsigned int)nsm){
         int j, max_remaining_waves = -1, candidate = -1;
-        int ks = min(16, nsm / (BlocksM * BlocksN));
+        int ks = min(20, nsm / (BlocksM * BlocksN));
         if (ks < 2 && (BlocksM * BlocksN) % nsm < nsm*4/5)
-            ks = 16;
+            ks = 20;
         for (j = 2; j <= ks; j++){
             const int remainder = (BlocksM * BlocksN * j) % nsm;
             // if ((P.c * P.r * P.s) % (8*j) == 0){
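The change above raises the split-K search cap from 16 to 20: the loop scores each candidate factor j by how the total block count BlocksM * BlocksN * j lands on the nsm streaming multiprocessors, so a larger cap lets small grids consider finer splits. Below is a minimal host-side sketch of this kind of wave-quantization search; the function name pick_split_factor and the exact scoring rule (prefer the fullest last wave) are illustrative assumptions, not the launcher's verbatim logic.

    // Hypothetical sketch: choose a split-K factor so the grid fills whole
    // waves of SMs as closely as possible (a larger remainder mod nsm means
    // a fuller, less wasteful final wave).
    static int pick_split_factor(int blocks_m, int blocks_n, int nsm, int max_split) {
        int best_j = -1, best_remainder = -1;
        for (int j = 2; j <= max_split; ++j) {
            const int remainder = (blocks_m * blocks_n * j) % nsm;
            if (remainder == 0) {
                return j;                   // whole number of waves: ideal
            }
            if (remainder > best_remainder) {
                best_remainder = remainder; // fullest last wave seen so far
                best_j = j;
            }
        }
        return best_j;
    }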
@@ -1328,7 +1328,20 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
             } else if (j == 16) {
                 launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 16,
                     ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+            } else if (j == 17) {
+                launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 17,
+                    ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+            } else if (j == 18) {
+                launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 18,
+                    ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+            } else if (j == 19) {
+                launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 19,
+                    ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+            } else if (j == 20) {
+                launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 20,
+                    ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+            }
             return;
         }
     }
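Each new split factor needs its own branch because the factor is a compile-time template argument of launch_conv2d_implicit_split_kernel, so extending the cap to 20 means four more else-if arms. As a sketch only (not the file's actual code), the same runtime-to-compile-time dispatch could be generated from an integer sequence; dispatch_split and launch_with_split are hypothetical names:

    #include <initializer_list>
    #include <utility>

    template <int J>
    static void launch_with_split(/* launch args elided */) {
        // would forward to launch_conv2d_implicit_split_kernel<..., J, ...>(...)
    }

    // Expands to one comparison per candidate factor, replacing the
    // hand-written else-if chain.
    template <int... Js>
    static bool dispatch_split(int j, std::integer_sequence<int, Js...>) {
        bool matched = false;
        (void)std::initializer_list<int>{
            (j == Js ? (launch_with_split<Js>(), matched = true, 0) : 0)...};
        return matched;
    }

    // usage: dispatch_split(j, std::make_integer_sequence<int, 21>{});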
@@ -1395,7 +1408,7 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor *
     const uint B = input->ne[3]; // n_batches
-    int64_t pp[3];
+    int64_t pp[3] = {0};
     // const unsigned int K = param.c;
     // const uint inChannelOffset = param.c * param.w;
     // const uint weightKOffset = param.c * param.r * param.s;
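The `= {0}` initializer above is a small correctness hardening: with an aggregate initializer, all three elements are value-initialized, whereas the previous declaration left them indeterminate. A one-line illustration (qq is hypothetical):

    int64_t pp[3] = {0}; // pp[0] == pp[1] == pp[2] == 0, guaranteed
    int64_t qq[3];       // indeterminate values; reading them before a write is UB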

@@ -37,7 +37,8 @@ typedef struct{
 /// Clears the predicates
 template<const unsigned int K_STRID>
-__host__ __device__ void clear_mask(unsigned int masks_[][2], bool clear = true) {
+// __host__ __device__ void clear_mask(unsigned int masks_[][2], bool clear = true) {
+__device__ void clear_mask(unsigned int masks_[][2], bool clear = true) {
 #pragma unroll
     for (int s = 0; s < K_STRID; ++s) {
@@ -47,7 +48,8 @@ __host__ __device__ void clear_mask(unsigned int masks_[][2], bool clear = true)
 }
 template<const unsigned int K_STRID>
-__host__ __device__ void add_byte_offset(int64_t element_offset[], const int64_t offset){
+// __host__ __device__ void add_byte_offset(int64_t element_offset[], const int64_t offset){
+__device__ void add_byte_offset(int64_t element_offset[], const int64_t offset){
 #pragma unroll
     for (int s = 0; s < K_STRID; ++s) {
         element_offset[s] += offset;
@@ -66,7 +68,7 @@ __device__ void prepareIteratorA(unsigned int thread_row,
     int offset_p[A_K_STRID];
     int offset_q[A_K_STRID];
-    constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
+    // constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
     // unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
     // const unsigned int chw = param.c * param.h * param.w;
@@ -436,15 +438,22 @@ __device__ __forceinline__ void tileMemcpyLoadB(
     const unsigned int curC = start_k+block_k+thread_col*8;
     const unsigned int ki = (curR*param.s+curS)*param.c + curC;
+    unsigned int iter_idx = thread_row * param.weightKOffset + ki;
+    unsigned int krow_idx = thread_row + blockIdx.x * TILE_ROWS;
+    const int ITER_STEPS = ROW_STEP * param.weightKOffset;
 #pragma unroll
     for (unsigned int i = 0; i < NUM_ITERS; i++){
-        const unsigned int src_index = thread_row * param.weightKOffset + ki;
-        if (thread_row + blockIdx.x * TILE_ROWS < param.k && curC < end_k){
+        // const unsigned int src_index = thread_row * param.weightKOffset + ki;
+        const unsigned int src_index = iter_idx;
+        // if (thread_row + blockIdx.x * TILE_ROWS < param.k && curC < end_k){
+        if (krow_idx < param.k && curC < end_k){
             dst_reg[i] = reinterpret_cast<const float4 *>(&src[src_index])[0];
         }else{ // read 4 halves
             dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);
         }
-        thread_row += ROW_STEP;
+        krow_idx += ROW_STEP;
+        iter_idx += ITER_STEPS;
     }
 #else
     GGML_UNUSED(src);
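The rewrite above is a classic strength reduction: the per-iteration product thread_row * param.weightKOffset and the bounds term thread_row + blockIdx.x * TILE_ROWS are hoisted out of the loop and advanced additively (iter_idx += ITER_STEPS, krow_idx += ROW_STEP), trading an integer multiply per iteration for a single add. A self-contained sketch of the pattern, with illustrative names (copy_rows, stride, base) rather than the kernel's real parameters:

    __device__ void copy_rows(const float4 * src, float4 * dst_reg,
                              unsigned int thread_row, unsigned int stride,
                              unsigned int base, unsigned int row_step,
                              unsigned int num_iters) {
        unsigned int iter_idx = thread_row * stride + base; // multiply once...
        const unsigned int step = row_step * stride;        // ...then add a fixed step
        for (unsigned int i = 0; i < num_iters; ++i) {
            dst_reg[i] = src[iter_idx]; // bounds/predicate guard omitted for brevity
            iter_idx += step;           // strength-reduced index update
        }
    }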
@@ -492,21 +501,24 @@ __device__ __forceinline__ void tileMemcpySwizzleStore(
     // to cover the whole tile
     constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
     constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
+    constexpr unsigned int ITER_STEPS = ROW_STEP * TILE_COLS_VECTORIZED;
     // unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
     // const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
     // compile time check that we provided the right amount of registers for storage
     static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
+    unsigned int iter_idx = thread_row * TILE_COLS_VECTORIZED + thread_col;
 #pragma unroll
     for (unsigned int i = 0; i < NUM_ITERS; i++)
     {
         // apply swizzle to the dst index
-        unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
+        // unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
+        unsigned int dst_index = iter_idx;
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
         dst_float4[dst_index] = src_reg[i];
-        thread_row += ROW_STEP;
+        iter_idx += ITER_STEPS;
     }
 #else
     GGML_UNUSED(src_reg);
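The same strength reduction is applied to the swizzled shared-memory store: the linear index now advances by ITER_STEPS instead of being recomputed from thread_row each pass, while the XOR swizzle is still re-applied to the fresh linear index every iteration, which is necessary because the swizzle does not commute with the additive update. A hedged sketch of the pattern (the mask/shift template parameters are placeholders, not the file's SWIZZLE_* constants):

    template <unsigned int MASK1, unsigned int BITS1,
              unsigned int MASK2, unsigned int BITS2>
    __device__ void swizzle_store(float4 * dst, const float4 * src_reg,
                                  unsigned int linear, unsigned int step,
                                  unsigned int iters) {
        for (unsigned int i = 0; i < iters; ++i) {
            unsigned int idx = linear;
            idx ^= (idx & MASK1) >> BITS1; // permute columns to spread banks
            idx ^= (idx & MASK2) >> BITS2;
            dst[idx] = src_reg[i];
            linear += step;                // cheap additive index update
        }
    }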