From 9b0a4eea57b2a25268f26971954a2994ca82f0b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Tue, 24 Mar 2026 17:25:42 +0100 Subject: [PATCH] ggml : replaced GGML_OP_WHERE_ID with GGML_OP_SCATTER that works similar to torch scatter_ operation. --- ggml/include/ggml.h | 8 ++-- ggml/src/ggml-cpu/ggml-cpu.c | 6 +-- ggml/src/ggml-cpu/ops.cpp | 33 +++++++------- ggml/src/ggml-cpu/ops.h | 2 +- ggml/src/ggml-cuda/ggml-cuda.cu | 8 ++-- ggml/src/ggml-cuda/scatter.cu | 72 ++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/scatter.cuh | 3 ++ ggml/src/ggml-cuda/where-id.cu | 78 --------------------------------- ggml/src/ggml-cuda/where-id.cuh | 3 -- ggml/src/ggml.c | 21 ++++----- src/llama-graph.cpp | 5 ++- 11 files changed, 117 insertions(+), 122 deletions(-) create mode 100644 ggml/src/ggml-cuda/scatter.cu create mode 100644 ggml/src/ggml-cuda/scatter.cuh delete mode 100644 ggml/src/ggml-cuda/where-id.cu delete mode 100644 ggml/src/ggml-cuda/where-id.cuh diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 82186fe8f6..48a5e6ee83 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -558,7 +558,7 @@ extern "C" { GGML_OP_SOLVE_TRI, GGML_OP_GATED_DELTA_NET, GGML_OP_HADAMARD, - GGML_OP_WHERE_ID, + GGML_OP_SCATTER, GGML_OP_UNARY, @@ -2480,11 +2480,11 @@ extern "C" { struct ggml_tensor * a, int n); - GGML_API struct ggml_tensor * ggml_where_id( + GGML_API struct ggml_tensor * ggml_scatter( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * ids); + struct ggml_tensor * ids, + float c); // custom operators diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index e5e5f0507e..7118439b83 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2029,9 +2029,9 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_hadamard(params, tensor); } break; - case GGML_OP_WHERE_ID: + case GGML_OP_SCATTER: { - ggml_compute_forward_where_id(params, tensor); + ggml_compute_forward_scatter(params, tensor); } break; case GGML_OP_MAP_CUSTOM1: { @@ -2356,7 +2356,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_GATED_LINEAR_ATTN: case GGML_OP_RWKV_WKV7: case GGML_OP_HADAMARD: - case GGML_OP_WHERE_ID: + case GGML_OP_SCATTER: { n_tasks = n_threads; } break; diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index c4a77b29e9..d720a6253a 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -11257,32 +11257,30 @@ void ggml_compute_forward_hadamard( } } -// ggml_compute_forward_where_id +// ggml_compute_forward_scatter -static void ggml_compute_forward_where_id_f32( +static void ggml_compute_forward_scatter_f32( const ggml_compute_params * params, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; - const ggml_tensor * src2 = dst->src[2]; + const float c = ggml_get_op_params_f32(dst, 0); - GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); GGML_ASSERT(dst->type == GGML_TYPE_F32); GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT(src2->type == GGML_TYPE_I32); + GGML_ASSERT(src1->type == GGML_TYPE_I32); GGML_ASSERT(src0->nb[0] == sizeof(float)); - GGML_ASSERT(src1->nb[0] == sizeof(float)); const int ith = params->ith; const int nth = params->nth; const int nr = ggml_nrows(src0); - GGML_TENSOR_TERNARY_OP_LOCALS + GGML_TENSOR_BINARY_OP_LOCALS GGML_ASSERT( nb0 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float)); @@ -11301,23 +11299,22 @@ static void ggml_compute_forward_where_id_f32( const int i1 = (ir - i3*ne2*ne1 - i2*ne1); const float * src0_ptr = (float *) ((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 ); - const float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 ); - const int32_t * ids_ptr = (int32_t *) ((char *) src2->data + i3*nb23 + i2*nb22 + i1*nb21); + const int32_t * ids_ptr = (int32_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); - // copy whole row from src1 - ggml_vec_cpy_f32(ne00, dst_ptr, src1_ptr); + // copy whole row from src0 + ggml_vec_cpy_f32(ne00, dst_ptr, src0_ptr); - // copy only values from src0 indicated by indices in src2 - for (int j = 0; j < ne20; ++j) { + // set dst elements indicated by indices in src1 to c + for (int j = 0; j < ne10; ++j) { int id = ids_ptr[j]; GGML_ASSERT(id >= 0 && id < ne00); - dst_ptr[id] = src0_ptr[id]; + dst_ptr[id] = c; } } } -void ggml_compute_forward_where_id( +void ggml_compute_forward_scatter( const ggml_compute_params * params, ggml_tensor * dst) { @@ -11326,11 +11323,11 @@ void ggml_compute_forward_where_id( switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_where_id_f32(params, dst); + ggml_compute_forward_scatter_f32(params, dst); } break; default: { - GGML_ABORT("unsupported type for ggml_compute_forward_where_id: %s", ggml_type_name(src0->type)); + GGML_ABORT("unsupported type for ggml_compute_forward_scatter: %s", ggml_type_name(src0->type)); } } } diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index 30b3e6d311..4fecd4651e 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -104,7 +104,7 @@ void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, s void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_gated_delta_net(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_hadamard(const struct ggml_compute_params * params, struct ggml_tensor * dst); -void ggml_compute_forward_where_id(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_scatter(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_map_custom3(const struct ggml_compute_params * params, struct ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index da2b54e137..4af7f2ba1d 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -62,7 +62,7 @@ #include "ggml-cuda/cumsum.cuh" #include "ggml-cuda/fill.cuh" #include "ggml-cuda/hadamard.cuh" -#include "ggml-cuda/where-id.cuh" +#include "ggml-cuda/scatter.cuh" #include "ggml.h" #include @@ -2776,8 +2776,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_HADAMARD: ggml_cuda_op_hadamard(ctx, dst); break; - case GGML_OP_WHERE_ID: - ggml_cuda_op_where_id(ctx, dst); + case GGML_OP_SCATTER: + ggml_cuda_op_scatter(ctx, dst); break; default: return false; @@ -5020,7 +5020,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_TRI: case GGML_OP_DIAG: case GGML_OP_SOLVE_TRI: - case GGML_OP_WHERE_ID: + case GGML_OP_SCATTER: return true; case GGML_OP_HADAMARD: return (op->ne[0] == 64 || op->ne[0] == 128 || op->ne[0] == 256) && op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32; diff --git a/ggml/src/ggml-cuda/scatter.cu b/ggml/src/ggml-cuda/scatter.cu new file mode 100644 index 0000000000..990b5cddb7 --- /dev/null +++ b/ggml/src/ggml-cuda/scatter.cu @@ -0,0 +1,72 @@ +#include "scatter.cuh" + +static __global__ void scatter_kernel( + const int32_t * src0, float * dst, const float c, + int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03, + size_t nb1, size_t nb2, size_t nb3, + size_t nb01, size_t nb02, size_t nb03 + ) { + + const int64_t total_blocks = ne01 * ne02 * ne03; + + for (int64_t block_idx = blockIdx.x; block_idx < total_blocks; block_idx += gridDim.x) { + + const int64_t i1 = block_idx % ne01; + const int64_t i2 = (block_idx / ne01) % ne02; + const int64_t i3 = block_idx / (ne01 * ne02); + + float * dst_row = (float *)((char *)dst + i1*nb1 + i2*nb2 + i3*nb3); + const int * src0_row = (const int *)((const char *)src0 + i1*nb01 + i2*nb02 + i3*nb03); + + for (int64_t i0 = threadIdx.x; i0 < ne00; i0 += blockDim.x) { + const int32_t id = src0_row[i0]; + dst_row[id] = c; + } + } +} + +void ggml_cuda_op_scatter(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_I32); + + GGML_ASSERT(nb00 == sizeof(float)); + GGML_ASSERT(nb10 == sizeof(int32_t)); + + GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(dst)); + + float c; + memcpy(&c, (float *) dst->op_params + 0, sizeof(float)); + + // step 1 - copy whole src0 to dst + cudaStream_t main_stream = ctx.stream(); + char * dst_ddc = (char *) dst->data; + char * src0_ddc = (char *) src0->data; + + CUDA_CHECK(cudaMemcpyAsync(dst_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); + + // step 2 - set elements in dst indicated by ids to c + const int32_t * src1_d = (const int32_t *) src1->data; + float * dst_d = (float *) dst->data; + + int threads = std::min((int) ne10, 768); // ids + + int64_t total_blocks = ne11 * ne12 * ne13; + int blocks = (int) std::min((int64_t) 65535, total_blocks); + + scatter_kernel<<>>( + src1_d, dst_d, c, + ne10, ne11, ne12, ne13, + nb1, nb2, nb3, + nb11, nb12, nb13 + ); +} diff --git a/ggml/src/ggml-cuda/scatter.cuh b/ggml/src/ggml-cuda/scatter.cuh new file mode 100644 index 0000000000..b435c992a6 --- /dev/null +++ b/ggml/src/ggml-cuda/scatter.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_scatter(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/where-id.cu b/ggml/src/ggml-cuda/where-id.cu deleted file mode 100644 index 2d9130035a..0000000000 --- a/ggml/src/ggml-cuda/where-id.cu +++ /dev/null @@ -1,78 +0,0 @@ -#include "where-id.cuh" - -static __global__ void where_id_kernel( - const float * src0, const int32_t * src1, float * dst, - int64_t ne10, int64_t ne11, int64_t ne12, int64_t ne13, - size_t nb1, size_t nb2, size_t nb3, - size_t nb01, size_t nb02, size_t nb03, - size_t nb11, size_t nb12, size_t nb13 - ) { - - const int64_t total_blocks = ne11 * ne12 * ne13; - - for (int64_t block_idx = blockIdx.x; block_idx < total_blocks; block_idx += gridDim.x) { - - const int64_t i1 = block_idx % ne11; - const int64_t i2 = (block_idx / ne11) % ne12; - const int64_t i3 = block_idx / (ne11 * ne12); - - float * dst_row = (float *)((char *)dst + i1*nb1 + i2*nb2 + i3*nb3); - const float * src0_row = (const float *)((const char *)src0 + i1*nb01 + i2*nb02 + i3*nb03); - const int * src1_row = (const int *)((const char *)src1 + i1*nb11 + i2*nb12 + i3*nb13); - - for (int64_t i0 = threadIdx.x; i0 < ne10; i0 += blockDim.x) { - const int32_t id = src1_row[i0]; - dst_row[id] = src0_row[id]; - } - } -} - -void ggml_cuda_op_where_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const ggml_tensor * src1 = dst->src[1]; - const ggml_tensor * src2 = dst->src[2]; - - GGML_TENSOR_TERNARY_OP_LOCALS - - GGML_ASSERT(ggml_is_contiguous(dst)); - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - GGML_ASSERT(ggml_is_contiguous(src2)); - - GGML_ASSERT(dst->type == GGML_TYPE_F32); - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT(src2->type == GGML_TYPE_I32); - - GGML_ASSERT(nb00 == sizeof(float)); - GGML_ASSERT(nb10 == sizeof(float)); - GGML_ASSERT(nb20 == sizeof(int32_t)); - - GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1)); - GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(dst)); - - // step 1 - copy whole src1 to dst - cudaStream_t main_stream = ctx.stream(); - char * dst_ddc = (char *) dst->data; - char * src1_ddc = (char *) src1->data; - - CUDA_CHECK(cudaMemcpyAsync(dst_ddc, src1_ddc, ggml_nbytes(src1), cudaMemcpyDeviceToDevice, main_stream)); - - // step 2 - copy elements from src0 indicated by ids to dst - const float * src0_d = (const float *) src0->data; - const int32_t * src2_d = (const int32_t *) src2->data; - float * dst_d = (float *) dst->data; - - int threads = std::min((int) ne20, 768); // ids - - int64_t total_blocks = ne21 * ne22 * ne23; - int blocks = (int) std::min((int64_t) 65535, total_blocks); - - where_id_kernel<<>>( - src0_d, src2_d, dst_d, - ne20, ne21, ne22, ne23, - nb1, nb2, nb3, - nb01, nb02, nb03, - nb21, nb22, nb23 - ); -} diff --git a/ggml/src/ggml-cuda/where-id.cuh b/ggml/src/ggml-cuda/where-id.cuh deleted file mode 100644 index bf3ea095a8..0000000000 --- a/ggml/src/ggml-cuda/where-id.cuh +++ /dev/null @@ -1,3 +0,0 @@ -#include "common.cuh" - -void ggml_cuda_op_where_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 7132c1f215..809e71d213 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1033,7 +1033,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "SOLVE_TRI", "GATED_DELTA_NET", "HADAMARD", - "WHERE_ID", + "SCATTER", "UNARY", @@ -1145,7 +1145,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "A X = B, A triangular, solve X", "gated_delta_net(q, k, v, g, beta, s)", "hadamard(x)", - "where_id(x,y,ids)", + "scatter(x,ids,c)", "unary(x)", @@ -6203,25 +6203,26 @@ struct ggml_tensor * ggml_hadamard( return result; } -// ggml_where_id +// ggml_scatter -struct ggml_tensor * ggml_where_id( +struct ggml_tensor * ggml_scatter( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * ids) { + struct ggml_tensor * ids, + float c) { GGML_ASSERT(a->type == GGML_TYPE_F32); - GGML_ASSERT(b->type == GGML_TYPE_F32); GGML_ASSERT(ids->type == GGML_TYPE_I32); GGML_ASSERT(a->ne[1] == ids->ne[1]); struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, a->ne); - result->op = GGML_OP_WHERE_ID; + float params[1] = { c }; + ggml_set_op_params(result, ¶ms, sizeof(params)); + + result->op = GGML_OP_SCATTER; result->src[0] = a; - result->src[1] = b; - result->src[2] = ids; + result->src[1] = ids; return result; } diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 8224e4873f..29d804638c 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -2177,7 +2177,10 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * kq_mask_all = ggml_fill(ctx0, kq_mask_f32, -INFINITY); // modify it by unmasking tokens that are in top_k indices - ggml_tensor * kq_mask_top_k = ggml_where_id(ctx0, kq_mask_f32, kq_mask_all, top_k); + ggml_tensor * kq_mask_top_k = ggml_scatter(ctx0, kq_mask_all, top_k, 0); + + // combine with the original kq mask + kq_mask_top_k = ggml_add(ctx0, kq_mask_top_k, kq_mask_f32); kq_mask_top_k = ggml_cast(ctx0, kq_mask_top_k, kq_mask->type); ggml_tensor * q = q_cur;