CUDA: Add `fastdiv` to `k_bin_bcast*`, giving 1-3% E2E performance (#15872)

* Add fastdiv and fastmodulo to k_bin_bcast kernel

* Address review comments

* `prod_` instead of `prod` suffix

* Add test case for `k_bin_bcast_unravel` in CUDA backend
Oliver Simons, 2025-09-10 22:04:03 +02:00, committed by GitHub
parent 4f658855fa
commit 00681dfc16
2 changed files with 111 additions and 74 deletions
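
The speedup comes from replacing per-thread integer division and modulo, which compile to slow multi-instruction sequences on GPUs, with ggml's `fastdiv`/`fastmodulo` helpers. The divisors here are tensor dimensions that are fixed for a given launch, so the host can precompute magic constants once and each thread can divide with a multiply-high plus an add and a shift. Below is a minimal sketch of the technique (division by invariant integers via multiplication, after Granlund & Montgomery); the real helpers live in ggml/src/ggml-cuda/common.cuh and may differ in detail, but the diff implies the same `uint3` packing of {multiplier, shift, divisor}:

    #include <cuda_runtime.h>
    #include <cstdint>

    // Host: precompute magic constants for division by a fixed d > 0 and
    // pack {multiplier mp, shift L, divisor d} into a single uint3 kernel argument.
    static uint3 init_fastdiv_values(uint32_t d) {
        uint32_t L = 0; // L = ceil(log2(d))
        while (L < 32 && (uint32_t{ 1 } << L) < d) {
            ++L;
        }
        const uint32_t mp = (uint32_t) (((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d)) / d + 1);
        return make_uint3(mp, L, d);
    }

    // Device: n / d as a 32x32 -> high-32 multiply, an add, and a shift
    // (exact for the index ranges used in these kernels, n < 2^31).
    static __device__ __forceinline__ uint32_t fastdiv(uint32_t n, const uint3 div) {
        return (__umulhi(n, div.x) + n) >> div.y;
    }

    // Device: n % d via n - (n / d) * d, reusing the divisor packed in div.z.
    static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, const uint3 div) {
        return n - fastdiv(n, div) * div.z;
    }

This is why the kernel parameters in the diff below change from `const int` to `const uint3` for every dimension used as a divisor: the host calls `init_fastdiv_values` once per launch, and each thread pays a few cheap integer ops instead of a hardware division.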

ggml/src/ggml-cuda/binbcast.cu

@@ -23,28 +23,44 @@ static __device__ __forceinline__ float op_div(const float a, const float b) {
     return a / b;
 }
 
-template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs>
-static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
-                                   const int ne0, const int ne1, const int ne2, const int ne3,
-                                   const int ne10, const int ne11, const int ne12, const int ne13,
-                                   /*int s0, */ const int s1, const int s2, const int s3,
-                                   /*int s00,*/ const int s01, const int s02, const int s03,
-                                   /*int s10,*/ const int s11, const int s12, const int s13,
-                                   src1_ptrs... src1s) {
-    const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
-    const int i1  = (blockDim.y*blockIdx.y + threadIdx.y);
-    const int i2  = (blockDim.z*blockIdx.z + threadIdx.z) / ne3;
-    const int i3  = (blockDim.z*blockIdx.z + threadIdx.z) % ne3;
-
-    if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+template <float (*bin_op)(const float, const float),
+          typename src0_t,
+          typename src1_t,
+          typename dst_t,
+          typename... src1_ptrs>
+static __global__ void k_bin_bcast(const src0_t * src0,
+                                   const src1_t * src1,
+                                   dst_t * dst,
+                                   const int ne0,
+                                   const int ne1,
+                                   const int ne2,
+                                   const uint3 ne3,
+                                   const uint3 ne10,
+                                   const uint3 ne11,
+                                   const uint3 ne12,
+                                   const uint3 ne13,
+                                   /*int s0, */ const int s1,
+                                   const int s2,
+                                   const int s3,
+                                   /*int s00,*/ const int s01,
+                                   const int s02,
+                                   const int s03,
+                                   /*int s10,*/ const int s11,
+                                   const int s12,
+                                   const int s13,
+                                   src1_ptrs... src1s) {
+    const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x;
+    const uint32_t i1  = (blockDim.y * blockIdx.y + threadIdx.y);
+    const uint32_t i2  = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3);
+    const uint32_t i3  = (blockDim.z * blockIdx.z + threadIdx.z) - (i2 * ne3.z);
+
+    if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3.z) {
         return;
     }
 
-    const int i11 = i1 % ne11;
-    const int i12 = i2 % ne12;
-    const int i13 = i3 % ne13;
+    const uint32_t i11 = fastmodulo(i1, ne11);
+    const uint32_t i12 = fastmodulo(i2, ne12);
+    const uint32_t i13 = fastmodulo(i3, ne13);
 
     const size_t i_src0 =  i3*s03 +  i2*s02 +  i1*s01;
     const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
@@ -53,8 +69,8 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
     const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
     dst_t * dst_row = dst + i_dst;
 
-    for (int i0 = i0s; i0 < ne0; i0 += blockDim.x*gridDim.x) {
-        const int i10 = i0 % ne10;
+    for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
+        const uint32_t i10 = fastmodulo(i0, ne10);
 
         float result = src0_row ? (float) src0_row[i0] : 0.0f;
 
         if constexpr (sizeof...(src1_ptrs) > 0) {
@@ -67,28 +83,48 @@
     }
 }
 
-template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs>
-static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
-                                           const int ne0, const int ne1, const int ne2,const int ne3,
-                                           const int ne10, const int ne11, const int ne12, const int ne13,
-                                           /*int s0, */ const int s1, const int s2, const int s3,
-                                           /*int s00,*/ const int s01, const int s02, const int s03,
-                                           /*int s10,*/ const int s11, const int s12, const int s13,
-                                           src1_ptrs ... src1s) {
+template <float (*bin_op)(const float, const float),
+          typename src0_t,
+          typename src1_t,
+          typename dst_t,
+          typename... src1_ptrs>
+static __global__ void k_bin_bcast_unravel(const src0_t * src0,
+                                           const src1_t * src1,
+                                           dst_t * dst,
+                                           const uint3 ne0,
+                                           const uint3 ne1,
+                                           const uint3 ne2,
+                                           const uint32_t ne3,
+                                           const uint3 prod_012,
+                                           const uint3 prod_01,
+                                           const uint3 ne10,
+                                           const uint3 ne11,
+                                           const uint3 ne12,
+                                           const uint3 ne13,
+                                           /*int s0, */ const int s1,
+                                           const int s2,
+                                           const int s3,
+                                           /*int s00,*/ const int s01,
+                                           const int s02,
+                                           const int s03,
+                                           /*int s10,*/ const int s11,
+                                           const int s12,
+                                           const int s13,
+                                           src1_ptrs... src1s) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
-    const int i3 = i/(ne2*ne1*ne0);
-    const int i2 = (i/(ne1*ne0)) % ne2;
-    const int i1 = (i/ne0) % ne1;
-    const int i0 = i % ne0;
-
-    if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+    const uint32_t i3 = fastdiv(i, prod_012);
+    const uint32_t i2 = fastdiv(i - i3 * prod_012.z, prod_01);
+    const uint32_t i1 = fastdiv(i - i3 * prod_012.z - i2 * prod_01.z, ne0);
+    const uint32_t i0 = i - i3 * prod_012.z - i2 * prod_01.z - i1 * ne0.z;
+
+    if (i0 >= ne0.z || i1 >= ne1.z || i2 >= ne2.z || i3 >= ne3) {
         return;
     }
 
-    const int i11 = i1 % ne11;
-    const int i12 = i2 % ne12;
-    const int i13 = i3 % ne13;
+    const int i11 = fastmodulo(i1, ne11);
+    const int i12 = fastmodulo(i2, ne12);
+    const int i13 = fastmodulo(i3, ne13);
 
     const size_t i_src0 =  i3*s03 +  i2*s02 +  i1*s01;
     const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
@@ -97,7 +133,7 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t *
     const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
     dst_t * dst_row = dst + i_dst;
 
-    const int i10 = i0 % ne10;
+    const int i10 = fastmodulo(i0, ne10);
 
     float result = src0_row ? (float) src0_row[i0] : 0.0f;
 
     if constexpr (sizeof...(src1_ptrs) > 0) {
@@ -170,11 +206,6 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
     //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
     //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);
 
-    int64_t ne10 = cne1[0];
-    int64_t ne11 = cne1[1];
-    int64_t ne12 = cne1[2];
-    int64_t ne13 = cne1[3];
-
     size_t nb0 = cnb[0];
     size_t nb1 = cnb[1];
     size_t nb2 = cnb[2];
@@ -233,48 +264,51 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
     block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
     block_dims.z = std::min(std::min<unsigned int>(ne2 * ne3, block_size / block_dims.x / block_dims.y), 64U);
 
-    dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x,
-                    (ne1 + block_dims.y - 1) / block_dims.y,
+    dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x, (ne1 + block_dims.y - 1) / block_dims.y,
                     (ne2 * ne3 + block_dims.z - 1) / block_dims.z);
 
+    const uint3 ne10 = init_fastdiv_values((uint32_t) cne1[0]);
+    const uint3 ne11 = init_fastdiv_values((uint32_t) cne1[1]);
+    const uint3 ne12 = init_fastdiv_values((uint32_t) cne1[2]);
+    const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]);
+
     if (block_nums.z > 65535) {
         int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
+
+        const uint3 prod_012    = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));
+        const uint3 prod_01     = init_fastdiv_values((uint32_t) (ne0 * ne1));
+        const uint3 ne0_fastdiv = init_fastdiv_values((uint32_t) ne0);
+        const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1);
+        const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2);
+
         if constexpr (sizeof...(I) > 0) {
-            k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
-                <<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd,
-                                                       ne0, ne1, ne2, ne3,
-                                                       ne10, ne11, ne12, ne13,
-                                                       /* s0, */ s1, s2, s3,
-                                                       /* s00,*/ s01, s02, s03,
-                                                       /* s10,*/ s11, s12,s13,
-                                                       (const src1_t *) dst->src[I + 1]->data...);
+            k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t><<<block_num, block_size, 0, stream>>>(
+                src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11,
+                ne12, ne13,
+                /* s0, */ s1, s2, s3,
+                /* s00,*/ s01, s02, s03,
+                /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
         } else {
             k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
-                <<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd,
-                                                       ne0, ne1, ne2, ne3,
-                                                       ne10, ne11, ne12, ne13,
+                <<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv,
+                                                       ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13,
                                                        /* s0, */ s1, s2, s3,
                                                        /* s00,*/ s01, s02, s03,
-                                                       /* s10,*/ s11, s12,s13);
+                                                       /* s10,*/ s11, s12, s13);
        }
    } else {
+        const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
+
        if constexpr (sizeof...(I) > 0) {
-            k_bin_bcast<bin_op, src0_t, src1_t, dst_t>
-                <<<block_nums, block_dims, 0, stream>>>(src0_dd, src1_dd, dst_dd,
-                                                        ne0, ne1, ne2, ne3,
-                                                        ne10, ne11, ne12, ne13,
-                                                        /* s0, */ s1, s2, s3,
-                                                        /* s00,*/ s01, s02, s03,
-                                                        /* s10,*/ s11, s12,s13,
-                                                        (const src1_t *) dst->src[I + 1]->data...);
+            k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
+                src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
+                /* s0, */ s1, s2, s3,
+                /* s00,*/ s01, s02, s03,
+                /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
        } else {
-            k_bin_bcast<bin_op, src0_t, src1_t, dst_t>
-                <<<block_nums, block_dims, 0, stream>>>(src0_dd, src1_dd, dst_dd,
-                                                        ne0, ne1, ne2, ne3,
-                                                        ne10, ne11, ne12, ne13,
-                                                        /* s0, */ s1, s2, s3,
-                                                        /* s00,*/ s01, s02, s03,
-                                                        /* s10,*/ s11, s12,s13);
+            k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
+                src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
+                /* s0, */ s1, s2, s3,
+                /* s00,*/ s01, s02, s03,
+                /* s10,*/ s11, s12, s13);
        }
    }
 }
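
Two dispatch paths are needed because CUDA caps the y and z grid dimensions at 65535 blocks. The common 3D-grid path only divides by `ne3` and takes broadcast remainders against `ne10..ne13`, so those are the only fastdiv values it needs. The 1D fallback `k_bin_bcast_unravel` must also unravel the flat thread index into (i0, i1, i2, i3), so the host additionally precomputes `prod_012 = ne0*ne1*ne2` and `prod_01 = ne0*ne1` as divisors; the kernel then peels off one coordinate per `fastdiv`, reusing each quotient in a subtraction (`i - i3 * prod_012.z - ...`) so no separate modulo is required.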

tests/test-backend-ops.cpp

@@ -6050,6 +6050,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 2, 2});
         add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 2, 2, 2});
 
+        // test case for k_bin_bcast_unravel in CUDA backend
+        add_test_bin_bcast(type, {1, 1, 65536, 1}, {256, 1, 1, 1});
+
         // stable diffusion
         add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 1, 1, 1});
         add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 16, 16, 1});
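
The new `{1, 1, 65536, 1}` with `{256, 1, 1, 1}` case is sized so that `ne2 * ne3 = 65536` pushes `block_nums.z` past the 65535 limit, forcing the `k_bin_bcast_unravel` fallback path, which previously lacked a dedicated test case in `test-backend-ops`.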