reorder register tile loop

This commit is contained in:
bssrdf 2025-10-08 13:52:56 -04:00
parent c6255442bb
commit 0ca43582e8
1 changed files with 16 additions and 8 deletions

View File

@ -183,12 +183,20 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
input_frag[(subcrs + 1) & 1][i + 4] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i + 32];
}
// #pragma unroll
// for (int i = 0; i < 8; ++i){
// auto weight_frag_i = ggml_cuda_cast<float>(weight_frag[subcrs % 2][i]);
// #pragma unroll
// for (int j = 0; j < 8; ++j){
// output_frag[i][j] += weight_frag_i * input_frag[subcrs % 2][j];
// }
// }
#pragma unroll
for (int i = 0; i < 8; ++i){
auto weight_frag_i = ggml_cuda_cast<float>(weight_frag[subcrs % 2][i]);
for (int j = 0; j < 8; ++j){
// auto weight_frag_i = ggml_cuda_cast<float>(weight_frag[subcrs % 2][i]);
#pragma unroll
for (int j = 0; j < 8; ++j){
output_frag[i][j] += weight_frag_i * input_frag[subcrs % 2][j];
for (int i = 0; i < 8; ++i){
output_frag[j][i] += ggml_cuda_cast<float>(weight_frag[subcrs % 2][i]) * input_frag[subcrs % 2][j];
}
}
}
@ -215,7 +223,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
for (int i = 0; i < 8; ++i){
#pragma unroll
for (int j = 0; j < 8; ++j){
output_frag[i][j] += ggml_cuda_cast<float>(weight_frag[1][i]) * input_frag[1][j];
output_frag[i][j] += ggml_cuda_cast<float>(weight_frag[1][j]) * input_frag[1][i];
}
}
}
@ -240,15 +248,15 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
#pragma unroll
for (int subj = 0; subj < 4; ++subj){
// output sts
smemoutput[output_sts_addr + subi * 8 * 4 + subj] = output_frag[i * 4 + subi][j * 4 + subj];
smemoutput[output_sts_addr + subj * 8 * 4 + subi] = output_frag[i * 4 + subi][j * 4 + subj];
}
}
__syncthreads();
#pragma unroll
for (int subk = 0; subk < 16; ++subk){
int outOffset = z * param.k * param.Oh * param.Ow + (m_idx + i * 16 + subk) * param.Oh * param.Ow + n_idx + j * 32;
if ((m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow)
int outOffset = z * param.k * param.Oh * param.Ow + (m_idx + j * 16 + subk) * param.Oh * param.Ow + n_idx + i * 32;
if ((m_idx + j * 16 + subk) < param.k && (n_idx + i * 32) < param.Oh * param.Ow)
output[outOffset] = smemoutput[output_lds_addr + subk * 32];
}
}