swizzling working, may still have room to optimize
This commit is contained in:
parent
76885c7697
commit
949eca4cba
|
|
@ -677,11 +677,14 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
|
|||
uint32_t (®_)[2] = reinterpret_cast<uint32_t(&)[2]>(acc_register_[mma_m][mma_n]);
|
||||
uint idx = output_sts_addr +
|
||||
mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N;
|
||||
uint idx8 = idx + 8 * BN / 2;
|
||||
idx = idx ^ ((idx & 0b110000000000) >> 9);
|
||||
idx = idx ^ ((idx & 0b1110000000) >> 4);
|
||||
uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[idx]);
|
||||
dst_ptr[0] = reg_[0];
|
||||
dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[idx + 8 * BN / 2]);
|
||||
idx8 = idx8 ^ ((idx8 & 0b110000000000) >> 9);
|
||||
idx8 = idx8 ^ ((idx8 & 0b1110000000) >> 4);
|
||||
dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[idx8]);
|
||||
dst_ptr[0] = reg_[1];
|
||||
}
|
||||
}
|
||||
|
|
@ -698,7 +701,6 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
|
|||
uint idx = output_lds_addr + subk*2 + j*32*BN/2;
|
||||
idx = idx ^ ((idx & 0b110000000000) >> 9);
|
||||
idx = idx ^ ((idx & 0b1110000000) >> 4);
|
||||
// uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[idx]);
|
||||
uint32_t dst_ptr = *(reinterpret_cast<uint32_t*>(&smemoutput[idx]));
|
||||
half (&res_)[2] = reinterpret_cast<half(&)[2]>(dst_ptr);
|
||||
if (n < param.n && row < param.k && col < PQ) {
|
||||
|
|
@ -706,13 +708,9 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
|
|||
const uint outOffset = z * NKPQ +
|
||||
n * KPQ +
|
||||
row * PQ + col;
|
||||
// output[outOffset] = smemoutput[idx];
|
||||
// output[outOffset] = reinterpret_cast<half *>(dst_ptr)[0];
|
||||
output[outOffset] = res_[0];
|
||||
} else {
|
||||
const uint outOffset = n * KPQ + row * PQ + col;
|
||||
// output[outOffset] = smemoutput[idx];
|
||||
// output[outOffset] = reinterpret_cast<half *>(dst_ptr)[0];
|
||||
output[outOffset] = res_[0];
|
||||
}
|
||||
}
|
||||
|
|
@ -721,13 +719,9 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
|
|||
const uint outOffset = z * NKPQ +
|
||||
n * KPQ +
|
||||
(row+1) * PQ + col;
|
||||
// output[outOffset] = smemoutput[idx];
|
||||
// output[outOffset] = reinterpret_cast<half *>(dst_ptr)[1];
|
||||
output[outOffset] = res_[1];
|
||||
} else {
|
||||
const uint outOffset = n * KPQ + (row+1) * PQ + col;
|
||||
// output[outOffset] = smemoutput[idx];
|
||||
// output[outOffset] = reinterpret_cast<half *>(dst_ptr)[1];
|
||||
output[outOffset] = res_[1];
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -705,16 +705,16 @@ int main(void)
|
|||
|
||||
|
||||
// for(int i = 0; i < ggml_nelements(wino_res); i++) {
|
||||
for(int i = 0; i < 26*38; i++) {
|
||||
// for(int i = 0; i < conv2d_data.size(); i++) {
|
||||
float diff = fabs(im2col_data[i] - conv2d_data[i]);
|
||||
// if(diff > 0.5) {
|
||||
printf("(%7.3f, %7.3f, %.2f, %d) \n",
|
||||
im2col_data[i], conv2d_data[i],
|
||||
diff, i);
|
||||
// break;
|
||||
// }
|
||||
}
|
||||
// for(int i = 0; i < 26*38; i++) {
|
||||
// // for(int i = 0; i < conv2d_data.size(); i++) {
|
||||
// float diff = fabs(im2col_data[i] - conv2d_data[i]);
|
||||
// // if(diff > 0.5) {
|
||||
// printf("(%7.3f, %7.3f, %.2f, %d) \n",
|
||||
// im2col_data[i], conv2d_data[i],
|
||||
// diff, i);
|
||||
// // break;
|
||||
// // }
|
||||
// }
|
||||
|
||||
ggml_free(model.ctx);
|
||||
ggml_backend_buffer_free(model.buffer);
|
||||
|
|
|
|||
Loading…
Reference in New Issue