This commit is contained in:
bssrdf 2025-10-21 17:12:50 -04:00
parent f0a480cc22
commit f931ad883f
2 changed files with 141 additions and 5 deletions

View File

@ -730,6 +730,8 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
}
}
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
template<const int BM, const int BN, const int BK, const int WM, const int WN,
const int WK, const int NUM_THREADS>
static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
@ -793,16 +795,16 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
// these register arrays are used to cache values pre-fetched from global memory during the inner loop of the kernel
// the code is nicer if we hard code it for these tile dimensions and number of threads
// since we performing this copy with float4 pointers, for these tile dimensions it works out to be 8 float4s for A and 4 float4s for B
static_assert(BM_dim == 256);
static_assert(BN_dim == 256);
static_assert(BK_dim == 32);
static_assert(BM == 256);
static_assert(BN == 256);
static_assert(BK == 32);
static_assert(NUM_THREADS == 256);
float4 A_gmem_cache_reg[4];
float4 B_gmem_cache_reg[4];
// prefetch the first block tile of A,B into shared memory
half* A_block_gmem = A + (block_m * BM_dim * A_stride);
half* B_block_gmem = B + (block_n * BN_dim);
half* A_block_gmem = input + (block_m * BM * A_stride);
half* B_block_gmem = weight + (block_n * BN);
tileMemcpySwizzleA<BM_dim, NUM_THREADS>(A_block_gmem, A_block_smem, K);
tileMemcpySwizzle<BK_dim, BN_dim, NUM_THREADS, SWIZZLE_BITS_B>(B_block_gmem, B_block_smem, N);
@ -905,6 +907,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
}
#endif
#define NUM_VARIANTS 6

View File

@ -25,6 +25,139 @@ typedef struct{
uint3 S_fastdiv;
} param_t;
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
// same as above, but writes are swizzled to avoid bank conflicts when shared memory is read later in the kernel
template<unsigned int TILE_ROWS,
unsigned int TILE_COLS,
unsigned int NUM_THREADS,
unsigned int SWIZZLE_BITS>
__device__ __forceinline__ void tileMemcpySwizzle(
half* src,
half* dst,
const unsigned int src_stride
)
{
constexpr unsigned int SWIZZLE_MASK = 0b111 << SWIZZLE_BITS;
// reinterpret input/output as float4
float4* src_float4 = reinterpret_cast<float4*>(src);
float4* dst_float4 = reinterpret_cast<float4*>(dst);
const unsigned int src_stride_vectorized = src_stride / 8;
// # of threads is multiple of # of columns in the tile
constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);
// flatten out 2d grid of threads into in order of increasing threadIdx.x
const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;
// assign each thread a row/column in the tile, calculate how many iterations we need
// to cover the whole tile
constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++)
{
// apply swizzle to the dst index
const unsigned int src_index = thread_row * src_stride_vectorized + thread_col;
unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK) >> SWIZZLE_BITS);
dst_float4[dst_index] = src_float4[src_index];
thread_row += ROW_STEP;
}
}
// this is a special case of the above for when TILE_COLS == 32
template<unsigned int TILE_ROWS,
unsigned int NUM_THREADS>
__device__ __forceinline__ void tileMemcpySwizzleA(
half* src,
half* dst,
const unsigned int src_stride
)
{
constexpr unsigned int SWIZZLE_MASK_1 = 0b10000;
constexpr unsigned int SWIZZLE_BITS_1 = 4;
constexpr unsigned int SWIZZLE_MASK_2 = 0b1100;
constexpr unsigned int SWIZZLE_BITS_2 = 2;
constexpr unsigned int TILE_COLS = 32;
// reinterpret input/output as float4
float4* src_float4 = reinterpret_cast<float4*>(src);
float4* dst_float4 = reinterpret_cast<float4*>(dst);
const unsigned int src_stride_vectorized = src_stride / 8;
// # of threads is multiple of # of columns in the tile
constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);
// flatten out 2d grid of threads into in order of increasing threadIdx.x
const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;
// assign each thread a row/column in the tile, calculate how many iterations we need
// to cover the whole tile
constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++)
{
// apply swizzle to the dst index
const unsigned int src_index = thread_row * src_stride_vectorized + thread_col;
unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
dst_float4[dst_index] = src_float4[src_index];
thread_row += ROW_STEP;
}
}
template<unsigned int TILE_ROWS,
unsigned int TILE_COLS,
unsigned int NUM_THREADS,
unsigned int ELEMENTS_PER_THREAD>
__device__ __forceinline__ void tileMemcpyLoad(
half* src,
float4 (&dst_reg)[ELEMENTS_PER_THREAD],
const unsigned int src_stride
)
{
// reinterpret input/output as float4
float4* src_float4 = reinterpret_cast<float4*>(src);
const unsigned int src_stride_vectorized = src_stride / 8;
// # of threads is multiple of # of columns in the tile
constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);
// flatten out 2d grid of threads into in order of increasing threadIdx.x
const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;
// assign each thread a row/column in the tile, calculate how many iterations we need
// to cover the whole tile
constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
// compile time check that we provided the right amount of registers for storage
static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++)
{
const unsigned int src_index = thread_row * src_stride_vectorized + thread_col;
dst_reg[i] = src_float4[src_index];
thread_row += ROW_STEP;
}
}
#endif
#define CUDA_CONV2D_IMPLICT_BLOCK_SIZE 256
void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst);