CUDA: Add `fastdiv` to `k_bin_bcast*`, giving 1-3% E2E performance (#15872)

* Add fastdiv and fastmodulo to k_bin_bcast kernel

* Address review comments

* Rename the precomputed-product variables to use a `prod_` prefix (e.g. `prod_012`, `prod_01`) instead of a `prod` suffix

* Add test case for `k_bin_bcast_unravel` in CUDA backend
This commit is contained in: (branch list not shown)
Author: Oliver Simons — 2025-09-10 22:04:03 +02:00, committed by GitHub
parent 4f658855fa
commit 00681dfc16
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 111 additions and 74 deletions

View File

@ -23,28 +23,44 @@ static __device__ __forceinline__ float op_div(const float a, const float b) {
return a / b; return a / b;
} }
template <float (*bin_op)(const float, const float),
typename src0_t,
template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs> typename src1_t,
static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst, typename dst_t,
const int ne0, const int ne1, const int ne2, const int ne3, typename... src1_ptrs>
const int ne10, const int ne11, const int ne12, const int ne13, static __global__ void k_bin_bcast(const src0_t * src0,
/*int s0, */ const int s1, const int s2, const int s3, const src1_t * src1,
/*int s00,*/ const int s01, const int s02, const int s03, dst_t * dst,
/*int s10,*/ const int s11, const int s12, const int s13, const int ne0,
const int ne1,
const int ne2,
const uint3 ne3,
const uint3 ne10,
const uint3 ne11,
const uint3 ne12,
const uint3 ne13,
/*int s0, */ const int s1,
const int s2,
const int s3,
/*int s00,*/ const int s01,
const int s02,
const int s03,
/*int s10,*/ const int s11,
const int s12,
const int s13,
src1_ptrs... src1s) { src1_ptrs... src1s) {
const int i0s = blockDim.x*blockIdx.x + threadIdx.x; const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x;
const int i1 = (blockDim.y*blockIdx.y + threadIdx.y); const uint32_t i1 = (blockDim.y * blockIdx.y + threadIdx.y);
const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3; const uint32_t i2 = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3);
const int i3 = (blockDim.z*blockIdx.z + threadIdx.z) % ne3; const uint32_t i3 = (blockDim.z * blockIdx.z + threadIdx.z) - (i2 * ne3.z);
if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3.z) {
return; return;
} }
const int i11 = i1 % ne11; const uint32_t i11 = fastmodulo(i1, ne11);
const int i12 = i2 % ne12; const uint32_t i12 = fastmodulo(i2, ne12);
const int i13 = i3 % ne13; const uint32_t i13 = fastmodulo(i3, ne13);
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
@ -53,8 +69,8 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr; const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
dst_t * dst_row = dst + i_dst; dst_t * dst_row = dst + i_dst;
for (int i0 = i0s; i0 < ne0; i0 += blockDim.x*gridDim.x) { for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
const int i10 = i0 % ne10; const uint32_t i10 = fastmodulo(i0, ne10);
float result = src0_row ? (float) src0_row[i0] : 0.0f; float result = src0_row ? (float) src0_row[i0] : 0.0f;
if constexpr (sizeof...(src1_ptrs) > 0) { if constexpr (sizeof...(src1_ptrs) > 0) {
@ -67,28 +83,48 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
} }
} }
template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs> template <float (*bin_op)(const float, const float),
static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst, typename src0_t,
const int ne0, const int ne1, const int ne2,const int ne3, typename src1_t,
const int ne10, const int ne11, const int ne12, const int ne13, typename dst_t,
/*int s0, */ const int s1, const int s2, const int s3, typename... src1_ptrs>
/*int s00,*/ const int s01, const int s02, const int s03, static __global__ void k_bin_bcast_unravel(const src0_t * src0,
/*int s10,*/ const int s11, const int s12, const int s13, const src1_t * src1,
src1_ptrs ... src1s) { dst_t * dst,
const uint3 ne0,
const uint3 ne1,
const uint3 ne2,
const uint32_t ne3,
const uint3 prod_012,
const uint3 prod_01,
const uint3 ne10,
const uint3 ne11,
const uint3 ne12,
const uint3 ne13,
/*int s0, */ const int s1,
const int s2,
const int s3,
/*int s00,*/ const int s01,
const int s02,
const int s03,
/*int s10,*/ const int s11,
const int s12,
const int s13,
src1_ptrs... src1s) {
const int i = blockDim.x*blockIdx.x + threadIdx.x; const int i = blockDim.x*blockIdx.x + threadIdx.x;
const int i3 = i/(ne2*ne1*ne0); const uint32_t i3 = fastdiv(i, prod_012);
const int i2 = (i/(ne1*ne0)) % ne2; const uint32_t i2 = fastdiv(i - i3 * prod_012.z, prod_01);
const int i1 = (i/ne0) % ne1; const uint32_t i1 = fastdiv(i - i3 * prod_012.z - i2 * prod_01.z, ne0);
const int i0 = i % ne0; const uint32_t i0 = i - i3 * prod_012.z - i2 * prod_01.z - i1 * ne0.z;
if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { if (i0 >= ne0.z || i1 >= ne1.z || i2 >= ne2.z || i3 >= ne3) {
return; return;
} }
const int i11 = i1 % ne11; const int i11 = fastmodulo(i1, ne11);
const int i12 = i2 % ne12; const int i12 = fastmodulo(i2, ne12);
const int i13 = i3 % ne13; const int i13 = fastmodulo(i3, ne13);
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
@ -97,7 +133,7 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t *
const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr; const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
dst_t * dst_row = dst + i_dst; dst_t * dst_row = dst + i_dst;
const int i10 = i0 % ne10; const int i10 = fastmodulo(i0, ne10);
float result = src0_row ? (float) src0_row[i0] : 0.0f; float result = src0_row ? (float) src0_row[i0] : 0.0f;
if constexpr (sizeof...(src1_ptrs) > 0) { if constexpr (sizeof...(src1_ptrs) > 0) {
@ -170,11 +206,6 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
//int64_t ne02 = cne0[2]; GGML_UNUSED(ne02); //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
//int64_t ne03 = cne0[3]; GGML_UNUSED(ne03); //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);
int64_t ne10 = cne1[0];
int64_t ne11 = cne1[1];
int64_t ne12 = cne1[2];
int64_t ne13 = cne1[3];
size_t nb0 = cnb[0]; size_t nb0 = cnb[0];
size_t nb1 = cnb[1]; size_t nb1 = cnb[1];
size_t nb2 = cnb[2]; size_t nb2 = cnb[2];
@ -233,48 +264,51 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x); block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
block_dims.z = std::min(std::min<unsigned int>(ne2 * ne3, block_size / block_dims.x / block_dims.y), 64U); block_dims.z = std::min(std::min<unsigned int>(ne2 * ne3, block_size / block_dims.x / block_dims.y), 64U);
dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x, dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x, (ne1 + block_dims.y - 1) / block_dims.y,
(ne1 + block_dims.y - 1) / block_dims.y,
(ne2 * ne3 + block_dims.z - 1) / block_dims.z); (ne2 * ne3 + block_dims.z - 1) / block_dims.z);
const uint3 ne10 = init_fastdiv_values((uint32_t) cne1[0]);
const uint3 ne11 = init_fastdiv_values((uint32_t) cne1[1]);
const uint3 ne12 = init_fastdiv_values((uint32_t) cne1[2]);
const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]);
if (block_nums.z > 65535) { if (block_nums.z > 65535) {
int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size; int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));
const uint3 prod_01 = init_fastdiv_values((uint32_t) (ne0 * ne1));
const uint3 ne0_fastdiv = init_fastdiv_values((uint32_t) ne0);
const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1);
const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2);
if constexpr (sizeof...(I) > 0) { if constexpr (sizeof...(I) > 0) {
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t> k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t><<<block_num, block_size, 0, stream>>>(
<<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11,
ne0, ne1, ne2, ne3, ne12, ne13,
ne10, ne11, ne12, ne13,
/* s0, */ s1, s2, s3, /* s0, */ s1, s2, s3,
/* s00,*/ s01, s02, s03, /* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12,s13, /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
(const src1_t *) dst->src[I + 1]->data...);
} else { } else {
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t> k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
<<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, <<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv,
ne0, ne1, ne2, ne3, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13,
ne10, ne11, ne12, ne13,
/* s0, */ s1, s2, s3, /* s0, */ s1, s2, s3,
/* s00,*/ s01, s02, s03, /* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12,s13); /* s10,*/ s11, s12, s13);
} }
} else { } else {
const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
if constexpr (sizeof...(I) > 0) { if constexpr (sizeof...(I) > 0) {
k_bin_bcast<bin_op, src0_t, src1_t, dst_t> k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
<<<block_nums, block_dims, 0, stream>>>(src0_dd, src1_dd, dst_dd, src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
ne0, ne1, ne2, ne3,
ne10, ne11, ne12, ne13,
/* s0, */ s1, s2, s3, /* s0, */ s1, s2, s3,
/* s00,*/ s01, s02, s03, /* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12,s13, /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
(const src1_t *) dst->src[I + 1]->data...);
} else { } else {
k_bin_bcast<bin_op, src0_t, src1_t, dst_t> k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
<<<block_nums, block_dims, 0, stream>>>(src0_dd, src1_dd, dst_dd, src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
ne0, ne1, ne2, ne3,
ne10, ne11, ne12, ne13,
/* s0, */ s1, s2, s3, /* s0, */ s1, s2, s3,
/* s00,*/ s01, s02, s03, /* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12,s13); /* s10,*/ s11, s12, s13);
} }
} }
} }

View File

@ -6050,6 +6050,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 2, 2}); add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 2, 2});
add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 2, 2, 2}); add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 2, 2, 2});
// test case for k_bin_bcast_unravel in CUDA backend
add_test_bin_bcast(type, {1, 1, 65536, 1}, {256, 1, 1, 1});
// stable diffusion // stable diffusion
add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 1, 1, 1}); add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 1, 1, 1});
add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 16, 16, 1}); add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 16, 16, 1});