trying to reduce integer ops; simply code

This commit is contained in:
bssrdf 2025-11-12 11:32:27 -05:00
parent c33e4301dc
commit ea438d8b0e
2 changed files with 27 additions and 37 deletions

View File

@ -805,6 +805,9 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
}
}
const unsigned int A_warp_tile_offset = warp_m * WM * BK;
const unsigned int B_warp_tile_offset = warp_n * WN * BK;
static_assert(BM == 256);
static_assert(BN == 256);
static_assert(BK == 32);
@ -825,13 +828,11 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
__syncthreads();
if (block_k != num_block_tiles_k){
const half* A_block_gmem = input;
const half* B_block_gmem = kernel + (block_n * BN * weightKOffset);
tileMemcpyLoadA<BM, BK, NUM_THREADS, 4>(A_block_gmem, A_gmem_cache_reg, block_k * BK, start_k, end_k, inChannelOffset, param);
tileMemcpyLoadB<BN, BK, NUM_THREADS, 4>(B_block_gmem, B_gmem_cache_reg, block_k * BK, start_k, end_k, weightKOffset, param);
}
half* A_warp_tile = A_block_smem + (warp_m * WM * BK);
half* B_warp_tile = B_block_smem + (warp_n * WN * BK);
half* A_warp_tile = A_block_smem + A_warp_tile_offset;
half* B_warp_tile = B_block_smem + B_warp_tile_offset;
ldmatrix_a<mma_tiles_per_warp_m, mma_tiles_per_warp_k, BK>(A_warp_tile, A_register_);
ldmatrix_b<mma_tiles_per_warp_k, mma_tiles_per_warp_n, BK>(B_warp_tile, B_register_);
@ -886,23 +887,25 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
const uint lane_id = threadIdx.x % WARPSIZE;
const uint mma_row = lane_id / 4;
const uint mma_col = lane_id % 4;
const uint output_lds_addr = warp_m * WM * BN/2 + lane_id * BN/2 + warp_n * WN/2;
const uint output_sts_addr = warp_m * WM * BN/2 + mma_row * BN/2 + warp_n * WN/2 + mma_col * 2;
const uint warp_offset = warp_m * WM * BN/2 + warp_n * WN/2;
const uint output_lds_addr = warp_offset + lane_id * BN/2;
const uint output_sts_addr = warp_offset + mma_row * BN/2 + mma_col * 2;
const uint m_idx = block_n * BN + warp_n * WN;
const uint n_idx = block_m * BM + warp_m * WM + lane_id;
#pragma unroll
for (int i = 0; i < 2; ++i)
{
const unsigned int i_offset = i * mma_tiles_per_warp_n/2;
__syncthreads();
#pragma unroll
for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++)
{
for (unsigned int mma_n = i * mma_tiles_per_warp_n/2; mma_n < (i+1)*mma_tiles_per_warp_n/2; mma_n++)
const unsigned int mma_m_offset = output_sts_addr + mma_m * MMA_M * BN / 2;
for (unsigned int mma_n = i_offset; mma_n < (i+1)*mma_tiles_per_warp_n/2; mma_n++)
{
uint32_t (&reg_)[2] = reinterpret_cast<uint32_t(&)[2]>(acc_register_[mma_m][mma_n]);
uint idx = output_sts_addr +
mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N;
uint idx = mma_m_offset + (mma_n - i_offset) * MMA_N;
idx = idx ^ ((idx & 0b110000000000) >> 9);
idx = idx ^ ((idx & 0b1110000000) >> 4);
uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[idx]);
@ -913,6 +916,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
}
}
__syncthreads();
const unsigned int m_i_wn = m_idx + i * WN / 2;
#pragma unroll
for (int subk = 0; subk < WN / 4; ++subk){
@ -925,29 +929,15 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
const uint gemm_i = n_idx + j*32;
const int n = fastdiv(gemm_i, param.OHOW_fastdiv);
const int col = fastmodulo(gemm_i, param.OHOW_fastdiv);
uint32_t dst_ptr = *(reinterpret_cast<uint32_t*>(&smemoutput[idx+j*32*BN/2]));
uint32_t dst_ptr = *(reinterpret_cast<uint32_t*>(&smemoutput[idx+j*16*BN])); // 32*BN/2 = 16*BN
half (&res_)[2] = reinterpret_cast<half(&)[2]>(dst_ptr);
if (n < param.n && row < param.k && col < PQ) {
if constexpr (ksplit > 0) {
const uint outOffset = z * NKPQ +
n * KPQ +
row * PQ + col;
output[outOffset] = ggml_cuda_cast<T>(res_[0]);
} else {
const uint outOffset = n * KPQ + row * PQ + col;
output[outOffset] = ggml_cuda_cast<T>(res_[0]);
}
const uint outOffset = ((ksplit > 0) ? z * NKPQ : 0) + n * KPQ + row * PQ + col;
output[outOffset] = ggml_cuda_cast<T>(res_[0]);
}
if (n < param.n && row+1 < param.k && col < PQ) {
if constexpr (ksplit > 0) {
const uint outOffset = z * NKPQ +
n * KPQ +
(row+1) * PQ + col;
output[outOffset] = ggml_cuda_cast<T>(res_[1]);
} else {
const uint outOffset = n * KPQ + (row+1) * PQ + col;
output[outOffset] = ggml_cuda_cast<T>(res_[1]);
}
const uint outOffset = ((ksplit > 0) ? z * NKPQ : 0) + n * KPQ + (row+1) * PQ + col;
output[outOffset] = ggml_cuda_cast<T>(res_[1]);
}
}
}

View File

@ -714,15 +714,15 @@ int main(void)
// for(int i = 0; i < ggml_nelements(wino_res); i++) {
// for(int i = 0; i < 26*38; i++) {
// for(int i = 0; i < conv2d_data.size(); i++) {
// float diff = fabs(im2col_data[i] - conv2d_data[i]);
// // if(diff > 0.5) {
// printf("(%7.3f, %7.3f, %.2f, %d) \n",
// im2col_data[i], conv2d_data[i],
// diff, i);
// // break;
// // }
// }
for(int i = 0; i < conv2d_data.size(); i++) {
float diff = fabs(im2col_data[i] - conv2d_data[i]);
// if(diff > 0.5) {
printf("(%7.3f, %7.3f, %.2f, %d) \n",
im2col_data[i], conv2d_data[i],
diff, i);
// break;
// }
}
ggml_free(model.ctx);
ggml_backend_buffer_free(model.buffer);