ggml : handle multiple streams in CUDA GGML_OP_WHERE_ID implementation

This commit is contained in:
Stanisław Szymczyk 2026-03-16 16:56:35 +01:00
parent 6c9d773669
commit cb94b565ad
1 changed files with 15 additions and 14 deletions

View File

@ -2,22 +2,23 @@
static __global__ void where_id_kernel(
const float * src0, const int32_t * src1, float * dst,
int64_t ne10, int64_t ne11, int64_t ne12,
size_t nb1, size_t nb2,
size_t nb01, size_t nb02,
size_t nb11, size_t nb12
int64_t ne10, int64_t ne11, int64_t ne12, int64_t ne13,
size_t nb1, size_t nb2, size_t nb3,
size_t nb01, size_t nb02, size_t nb03,
size_t nb11, size_t nb12, size_t nb13
) {
const int64_t total_blocks = ne11 * ne12;
const int64_t total_blocks = ne11 * ne12 * ne13;
for (int64_t block_idx = blockIdx.x; block_idx < total_blocks; block_idx += gridDim.x) {
const int64_t i1 = block_idx % ne11;
const int64_t i2 = block_idx / ne11;
const int64_t i2 = (block_idx / ne11) % ne12;
const int64_t i3 = block_idx / (ne11 * ne12);
float * dst_row = (float *)((char *)dst + i1*nb1 + i2*nb2);
const float * src0_row = (const float *)((const char *)src0 + i1*nb01 + i2*nb02);
const int * src1_row = (const int *)((const char *)src1 + i1*nb11 + i2*nb12);
float * dst_row = (float *)((char *)dst + i1*nb1 + i2*nb2 + i3*nb3);
const float * src0_row = (const float *)((const char *)src0 + i1*nb01 + i2*nb02 + i3*nb03);
const int * src1_row = (const int *)((const char *)src1 + i1*nb11 + i2*nb12 + i3*nb13);
for (int64_t i0 = threadIdx.x; i0 < ne10; i0 += blockDim.x) {
const int32_t id = src1_row[i0];
@ -64,14 +65,14 @@ void ggml_cuda_op_where_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
int threads = std::min((int) ne20, 768); // ids
int64_t total_blocks = ne21 * ne22;
int64_t total_blocks = ne21 * ne22 * ne23;
int blocks = (int) std::min((int64_t) 65535, total_blocks);
where_id_kernel<<<blocks, threads, 0, ctx.stream()>>>(
src0_d, src2_d, dst_d,
ne20, ne21, ne22,
nb1, nb2,
nb01, nb02,
nb21, nb22
ne20, ne21, ne22, ne23,
nb1, nb2, nb3,
nb01, nb02, nb03,
nb21, nb22, nb23
);
}