fixed bug now split-k is working

This commit is contained in:
bssrdf 2025-11-05 13:47:38 -05:00
parent 6f44f47113
commit 688de6d7d8
3 changed files with 5 additions and 8 deletions

View File

@ -819,9 +819,6 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
}
}
if (block_k != num_block_tiles_k)
{
// switch smem buffers each iteration

View File

@ -66,7 +66,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++){
// apply swizzle to the dst index
const unsigned int src_index = thread_row * src_stride + thread_col * 8;
const unsigned int src_index = thread_row * src_stride + start_k + thread_col * 8;
unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
@ -262,7 +262,7 @@ __device__ __forceinline__ void tileMemcpyLoadB(
#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++){
const unsigned int src_index = thread_row * src_stride + block_k + thread_col * 8;
const unsigned int src_index = thread_row * src_stride + start_k + block_k + thread_col * 8;
if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c && start_k+block_k+thread_col*8 < end_k){
dst_reg[i] = reinterpret_cast<const float4 *>(&src[src_index])[0];
}else{ // read 4 halves

View File

@ -44,7 +44,7 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, int kw = 3,
// create data
int KW = kw, KH = kh, IC = ic, OC = oc;
int IW = iw, IH = ih, N = 1;
srand(time(NULL));
// srand(time(NULL));
// printf(" input: IC = %d, OC = %d, IW = %d, IH = %d \n ", IC, OC, IW, IH);
@ -384,8 +384,8 @@ int main(void)
// for(int i = 0; i < ggml_nelements(wino_res); i++) {
// for(int i = 0; i < 26*38; i++) {
// // for(int i = 26*38; i < 2*26*38; i++) {
// // for(int i = 0; i < conv2d_data.size(); i++) {
// for(int i = 26*38; i < 2*26*38; i++) {
// for(int i = 0; i < conv2d_data.size(); i++) {
// float diff = fabs(im2col_data[i] - conv2d_data[i]);
// // if(diff > 0.5) {
// printf("(%7.3f, %7.3f, %.2f, %d) \n",