further reduce repeated index computations

This commit is contained in:
bssrdf 2025-11-18 18:36:45 -05:00
parent ba754ce4f3
commit 73444564e6
2 changed files with 20 additions and 7 deletions

View File

@ -871,6 +871,11 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
prepareIteratorA<BM, BK, A_K_STRID, ROW_STEP>(thread_row, masks_a, element_offset_a, param);
unsigned int iter_src_idx = thread_row * param.weightKOffset;
unsigned int iter_dst_idx = thread_row * TILE_COLS_VECTORIZED + thread_col;
unsigned int krow_idx = thread_row + blockIdx.x * BN;
const int ITER_SRC_STEPS = ROW_STEP * param.weightKOffset;
// prefetch the first block tile of A,B into shared memory
@ -923,11 +928,13 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
if (block_krs != num_block_tiles_krs) {
#ifdef CP_ASYNC_AVAILABLE
curC = tileMemcpyAsyncLoadA<BM, BK, NUM_THREADS, 4>(A_block_gmem, SA2, r, s,
masks_a, element_offset_a, thread_row, thread_col, block_k * BK,
masks_a, element_offset_a, thread_row, thread_col,
iter_dst_idx, block_k * BK,
start_k, end_k, curC, param);
element_offset_b = (r*param.s+s)*param.c + curC;
tileMemcpyAsyncLoadB<BN, BK, NUM_THREADS, 4>(B_block_gmem, SB2, r, s, curC, element_offset_b, block_k * BK,
start_k, end_k, thread_row, thread_col, param);
start_k, end_k, thread_row, thread_col,
iter_src_idx, iter_dst_idx, krow_idx, ITER_SRC_STEPS,param);
asm volatile("cp.async.commit_group;\n" ::);
#else
curC = tileMemcpyLoadA<BM, BK, NUM_THREADS, 4>(A_block_gmem, A_gmem_cache_reg, r, s,

View File

@ -343,6 +343,7 @@ __device__ __forceinline__ unsigned int tileMemcpyAsyncLoadA(
const int64_t element_offset[],
unsigned int thread_row,
const unsigned int thread_col,
unsigned int iter_idx,
const unsigned int block_k,
const unsigned int start_k,
const unsigned int end_k,
@ -369,7 +370,6 @@ __device__ __forceinline__ unsigned int tileMemcpyAsyncLoadA(
if (curC > oldC)
clear_mask<NUM_ITERS>(masks, curC >= end_k);
unsigned int iter_idx = thread_row * TILE_COLS_VECTORIZED + thread_col;
#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++) {
bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS));
@ -393,6 +393,7 @@ __device__ __forceinline__ unsigned int tileMemcpyAsyncLoadA(
GGML_UNUSED(element_offset);
GGML_UNUSED(thread_row);
GGML_UNUSED(thread_col);
GGML_UNUSED(iter_idx);
GGML_UNUSED(oldC);
GGML_UNUSED(param);
NO_DEVICE_CODE;
@ -480,6 +481,10 @@ __device__ __forceinline__ void tileMemcpyAsyncLoadB(
const unsigned int end_k,
unsigned int thread_row,
const unsigned int thread_col,
unsigned int iter_src_idx,
unsigned int iter_dst_idx,
unsigned int krow_idx,
const int ITER_SRC_STEPS,
param_t param
) {
@ -500,10 +505,7 @@ __device__ __forceinline__ void tileMemcpyAsyncLoadB(
// compile time check that we provided the right amount of registers for storage
static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
unsigned int iter_src_idx = thread_row * param.weightKOffset + ki;
unsigned int iter_dst_idx = thread_row * TILE_COLS_VECTORIZED + thread_col;
unsigned int krow_idx = thread_row + blockIdx.x * TILE_ROWS;
const int ITER_SRC_STEPS = ROW_STEP * param.weightKOffset;
iter_src_idx += ki;
#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++) {
@ -529,6 +531,10 @@ __device__ __forceinline__ void tileMemcpyAsyncLoadB(
GGML_UNUSED(end_k);
GGML_UNUSED(thread_row);
GGML_UNUSED(thread_col);
GGML_UNUSED(iter_src_idx);
GGML_UNUSED(iter_dst_idx);
GGML_UNUSED(krow_idx);
GGML_UNUSED(ITER_SRC_STEPS);
GGML_UNUSED(param);
NO_DEVICE_CODE;
#endif