WIP: fixed another bug

This commit is contained in:
bssrdf 2025-10-25 20:24:14 -04:00
parent 396f55831c
commit 475f9879c5
2 changed files with 54 additions and 22 deletions

View File

@ -931,7 +931,8 @@ __device__ __forceinline__ void ldmatrix_b(
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
// "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
: "=r"(reg_[0][4]), "=r"(reg_[0][5]), "=r"(reg_[0][6]), "=r"(reg_[0][7])
// : "r"(src_addr ^ 0b1000000)
@ -941,14 +942,16 @@ __device__ __forceinline__ void ldmatrix_b(
src_addr ^= 0b10000;
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
// "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
: "=r"(reg_[1][0]), "=r"(reg_[1][1]), "=r"(reg_[1][2]), "=r"(reg_[1][3])
: "r"(src_addr)
);
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
// "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
: "=r"(reg_[1][4]), "=r"(reg_[1][5]), "=r"(reg_[1][6]), "=r"(reg_[1][7])
// : "r"(src_addr ^ 0b1000000)
@ -959,14 +962,16 @@ __device__ __forceinline__ void ldmatrix_b(
src_addr ^= 0b110000;
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
// "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
: "=r"(reg_[2][0]), "=r"(reg_[2][1]), "=r"(reg_[2][2]), "=r"(reg_[2][3])
: "r"(src_addr)
);
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
// "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
: "=r"(reg_[2][4]), "=r"(reg_[2][5]), "=r"(reg_[2][6]), "=r"(reg_[2][7])
// : "r"(src_addr ^ 0b1000000)
@ -976,14 +981,16 @@ __device__ __forceinline__ void ldmatrix_b(
src_addr ^= 0b10000;
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
// "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
: "=r"(reg_[3][0]), "=r"(reg_[3][1]), "=r"(reg_[3][2]), "=r"(reg_[3][3])
: "r"(src_addr)
);
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
// "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
: "=r"(reg_[3][4]), "=r"(reg_[3][5]), "=r"(reg_[3][6]), "=r"(reg_[3][7])
// : "r"(src_addr ^ 0b1000000)
@ -1043,6 +1050,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
// declare register storage
// ptx instructions expect uint32_t registers, where each uint32_t is 2 halfs packed together
uint32_t acc_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][2];
// float acc_register_[mma_tiles_per_warp_m][mma_tiles_per_warp_n][4];
uint32_t A_register[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2];
uint32_t B_register[mma_tiles_per_warp_k][mma_tiles_per_warp_n];
@ -1131,16 +1139,40 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
"r"(B_register[mma_k][mma_n])
"r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1])
);
// asm volatile (
// "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
// "{%0, %1, %2, %3},"
// "{%4, %5},"
// "{%6},"
// "{%7, %8, %9, %10};\n"
// : "=f"(acc_register_[mma_m][mma_n][0]), "=f"(acc_register_[mma_m][mma_n][1]),
// "=f"(acc_register_[mma_m][mma_n][2]), "=f"(acc_register_[mma_m][mma_n][3])
// : "r"(A_register[mma_m][mma_k][0]), "r"(A_register[mma_m][mma_k][1]),
// "r"(B_register[mma_k][mma_n]),
// "f"(acc_register_[mma_m][mma_n][0]), "f"(acc_register_[mma_m][mma_n][1]),
// "f"(acc_register_[mma_m][mma_n][2]), "f"(acc_register_[mma_m][mma_n][3])
// );
}
}
if(threadIdx.x == 28 && threadIdx.y ==0 && blockIdx.x ==0 && blockIdx.y ==0){
printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(acc_register_[0][0][0]), __half2float(acc_register_[0][0][1]),
__half2float(acc_register_[0][0][2]), __half2float(acc_register_[0][0][3]));
printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(A_register_[0][mma_k][0]), __half2float(A_register_[0][mma_k][1]),
__half2float(A_register_[0][mma_k][2]), __half2float(A_register_[0][mma_k][3]));
printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(B_register_[mma_k][0][0]), __half2float(B_register_[mma_k][0][1]),
__half2float(B_register_[mma_k][0][2]), __half2float(B_register_[mma_k][0][3]));
}
// if(threadIdx.x == 12 && threadIdx.y ==0 && blockIdx.x ==0 && blockIdx.y ==0){
// printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(acc_register_[0][0][0]), __half2float(acc_register_[0][0][1]),
// __half2float(acc_register_[0][0][2]), __half2float(acc_register_[0][0][3]));
// printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, acc_register_[0][0][0], acc_register_[0][0][1],
// acc_register_[0][0][2], acc_register_[0][0][3]);
// printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(A_register_[0][mma_k][0]), __half2float(A_register_[0][mma_k][1]),
// __half2float(A_register_[0][mma_k][2]), __half2float(A_register_[0][mma_k][3]));
// printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(B_register_[mma_k][0][0]), __half2float(B_register_[mma_k][0][1]),
// __half2float(B_register_[mma_k][0][2]), __half2float(B_register_[mma_k][0][3]));
// printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, acc_register_[1][0][0], acc_register_[1][0][1],
// acc_register_[1][0][2], acc_register_[1][0][3]);
// printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(A_register_[1][mma_k][0]), __half2float(A_register_[1][mma_k][1]),
// __half2float(A_register_[1][mma_k][2]), __half2float(A_register_[1][mma_k][3]));
// printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, acc_register_[3][0][0], acc_register_[3][0][1],
// acc_register_[3][0][2], acc_register_[3][0][3]);
// printf(" %d, %d: %f, %f, %f, %f \n", block_k, mma_k, __half2float(A_register_[3][mma_k][0]), __half2float(A_register_[3][mma_k][1]),
// __half2float(A_register_[3][mma_k][2]), __half2float(A_register_[3][mma_k][3]));
// printf(" %d, %d: %f, %f, \n", block_k, mma_k, __half2float(B_register_[mma_k][0][0]), __half2float(B_register_[mma_k][0][1]));
// }
// if(threadIdx.x < 4 && threadIdx.y ==0 && blockIdx.x ==0 && blockIdx.y ==0){
// printf("A %d, %d, %d: %f, %f \n", block_k, mma_k, threadIdx.x, __half2float(A_register_[3][mma_k][0]), __half2float(A_register_[3][mma_k][1]));
// printf("B %d, %d, %d: %f, %f \n", block_k, mma_k, threadIdx.x, __half2float(B_register_[mma_k][0][0]), __half2float(B_register_[mma_k][0][1]));

View File

@ -50,10 +50,10 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, bool use_gpu
std::vector<float> adata(KW * KH * IC * OC);
for (int i = 0; i < KW * KH * IC * OC; i++) {
// adata[i] = 2.f;
adata[i] = (float)(i%KW)-1.f;
// adata[i] = (float)(i%KW)-1.f;
// adata[i] = (rand() % 255) / 255.0;
// float r = -1.f + static_cast <float> (rand()) /( static_cast <float> (RAND_MAX/(1.f-(-1.f))));
// adata[i] = r;
float r = -1.f + static_cast <float> (rand()) /( static_cast <float> (RAND_MAX/(1.f-(-1.f))));
adata[i] = r;
}
// Convert adata to fp16 format
@ -63,11 +63,11 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, bool use_gpu
// Initialize bdata
std::vector<float> bdata(IW * IH * IC * N);
for (int i = 0; i < IW * IH * IC * N; i++) {
bdata[i] = (float)(i%IW)/10.f;
// bdata[i] = (float)(i%IW)/10.f;
// bdata[i] = 1.5f;
// bdata[i] = (rand() % 255) / 255.0;
// float r = -1.f + static_cast <float> (rand()) /( static_cast <float> (RAND_MAX/(1.f-(-1.f))));
// bdata[i] = r;
float r = -1.f + static_cast <float> (rand()) /( static_cast <float> (RAND_MAX/(1.f-(-1.f))));
bdata[i] = r;
}
size_t buffer_size = 0;
@ -452,7 +452,7 @@ int main(void)
float diff = fabs(im2col_data[i] - wino_data[i]);
float diff1 = fabs(im2col_data[i] - conv2d_data[i]);
// if(diff > 1.e-4) {
printf("(%f, %f, %f, %f, %f, %d) \n",
printf("(%7.3f, %7.3f, %7.3f, %.2f, %.2f, %d) \n",
im2col_data[i], conv2d_data[i],
wino_data[i], diff, diff1, i);
// break;