From 5677f082b0d37ec6bc9eaf6d755e22197c51948a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?=
Date: Wed, 25 Mar 2026 11:08:29 +0100
Subject: [PATCH] ggml : added f16 version of GGML_OP_SCATTER

---
 ggml/src/ggml-cpu/ops.cpp     | 66 +++++++++++++++++++++++++++++++++++
 ggml/src/ggml-cuda/scatter.cu | 38 +++++++++++++-------
 ggml/src/ggml.c               |  2 +-
 tests/test-backend-ops.cpp    |  6 ++++
 4 files changed, 99 insertions(+), 13 deletions(-)

diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 86eeaa479a..31040e278b 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -11318,6 +11318,68 @@ static void ggml_compute_forward_scatter_f32(
     }
 }
 
+static void ggml_compute_forward_scatter_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    const ggml_fp16_t c = GGML_CPU_FP32_TO_FP16(ggml_get_op_params_f32(dst, 0));
+    const bool inplace = ggml_get_op_params_i32(dst, 1);
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    GGML_ASSERT(dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+
+    GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src0);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); // src0 strides (nb0x), not dst's
+        const int32_t * ids_ptr = (int32_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
+        ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
+
+        // copy whole row from src0
+        if (!inplace) {
+            // ggml_vec_cpy_f16(ne00, dst_ptr, src0_ptr)
+            for (int i = 0; i < ne00; ++i) {
+                dst_ptr[i] = src0_ptr[i];
+            }
+        }
+
+        // set dst elements indicated by indices in src1 to c
+        for (int j = 0; j < ne10; ++j) {
+            int id = ids_ptr[j];
+            GGML_ASSERT(id >= 0 && id < ne00);
+            dst_ptr[id] = c;
+        }
+    }
+}
+
 void ggml_compute_forward_scatter(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
@@ -11329,6 +11391,10 @@ void ggml_compute_forward_scatter(
         {
             ggml_compute_forward_scatter_f32(params, dst);
         } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_scatter_f16(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("unsupported type for ggml_compute_forward_scatter: %s", ggml_type_name(src0->type));
diff --git a/ggml/src/ggml-cuda/scatter.cu b/ggml/src/ggml-cuda/scatter.cu
index 0c252dad65..6dacb28b52 100644
--- a/ggml/src/ggml-cuda/scatter.cu
+++ b/ggml/src/ggml-cuda/scatter.cu
@@ -1,7 +1,9 @@
 #include "scatter.cuh"
+#include "convert.cuh"
 
+template <typename T>
 static __global__ void scatter_kernel(
-    const int32_t * src0, float * dst, const float c,
+    const int32_t * src0, T * dst, const T c,
     int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
     size_t nb1, size_t nb2, size_t nb3,
     size_t nb01, size_t nb02, size_t nb03
@@ -15,7 +17,7 @@ static __global__ void scatter_kernel(
     const int64_t i2 = (block_idx / ne01) % ne02;
     const int64_t i3 = block_idx / (ne01 * ne02);
 
-    float * dst_row = (float *)((char *)dst + i1*nb1 + i2*nb2 + i3*nb3);
+    T * dst_row = (T *)((char *)dst + i1*nb1 + i2*nb2 + i3*nb3);
     const int * src0_row = (const int *)((const char *)src0 + i1*nb01 + i2*nb02 + i3*nb03);
 
     for (int64_t i0 = threadIdx.x; i0 < ne00; i0 += blockDim.x) {
@@ -35,11 +37,9 @@ void ggml_cuda_op_scatter(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT(ggml_is_contiguous(src0));
     GGML_ASSERT(ggml_is_contiguous(src1));
 
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == src0->type);
     GGML_ASSERT(src1->type == GGML_TYPE_I32);
 
-    GGML_ASSERT(nb00 == sizeof(float));
     GGML_ASSERT(nb10 == sizeof(int32_t));
 
     GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(dst));
@@ -58,17 +58,31 @@ void ggml_cuda_op_scatter(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     // step 2 - set elements in dst indicated by ids to c
     const int32_t * src1_d = (const int32_t *) src1->data;
-    float * dst_d = (float *) dst->data;
+    void * dst_d = dst->data;
 
     int threads = std::min((int) ne10, 512); // ids
     int64_t total_blocks = ne11 * ne12 * ne13;
     int blocks = (int) std::min((int64_t) 65535, total_blocks);
 
-    scatter_kernel<<<blocks, threads, 0, ctx.stream()>>>(
-        src1_d, dst_d, c,
-        ne10, ne11, ne12, ne13,
-        nb1, nb2, nb3,
-        nb11, nb12, nb13
-    );
+    switch (dst->type) {
+        case GGML_TYPE_F32:
+            scatter_kernel<<<blocks, threads, 0, ctx.stream()>>>(
+                src1_d, (float *) dst_d, c,
+                ne10, ne11, ne12, ne13,
+                nb1, nb2, nb3,
+                nb11, nb12, nb13
+            );
+            break;
+        case GGML_TYPE_F16:
+            scatter_kernel<<<blocks, threads, 0, ctx.stream()>>>(
+                src1_d, (half *) dst_d, ggml_cuda_cast<half>(c),
+                ne10, ne11, ne12, ne13,
+                nb1, nb2, nb3,
+                nb11, nb12, nb13
+            );
+            break;
+        default:
+            GGML_ABORT("unsupported type");
+    }
 
 }
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 82a889cbfa..9744813f45 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -6212,7 +6212,7 @@ static struct ggml_tensor * ggml_scatter_impl(
         float c,
         bool inplace) {
-    GGML_ASSERT(a->type == GGML_TYPE_F32);
+    GGML_ASSERT(a->type == GGML_TYPE_F32 || a->type == GGML_TYPE_F16);
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
 
     GGML_ASSERT(a->ne[1] == ids->ne[1]);
 
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index b615702a29..f8318a14ef 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -8538,6 +8538,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_scatter(GGML_TYPE_F32, GGML_TYPE_I32, {10, 1, 1, 1}, {3, 1, 1, 1}, 0.0f, false));
     test_cases.emplace_back(new test_scatter(GGML_TYPE_F32, GGML_TYPE_I32, {10, 10, 10, 10}, {3, 10, 10, 10}, 0.0f, true));
     test_cases.emplace_back(new test_scatter(GGML_TYPE_F32, GGML_TYPE_I32, {10, 10, 10, 10}, {3, 10, 10, 10}, 0.0f, false));
+    test_cases.emplace_back(new test_scatter(GGML_TYPE_F16, GGML_TYPE_I32, {10, 1, 1, 1}, {3, 1, 1, 1}, 0.0f, true));
+    test_cases.emplace_back(new test_scatter(GGML_TYPE_F16, GGML_TYPE_I32, {10, 1, 1, 1}, {3, 1, 1, 1}, 0.0f, false));
+    test_cases.emplace_back(new test_scatter(GGML_TYPE_F16, GGML_TYPE_I32, {10, 10, 10, 10}, {3, 10, 10, 10}, 0.0f, true));
+    test_cases.emplace_back(new test_scatter(GGML_TYPE_F16, GGML_TYPE_I32, {10, 10, 10, 10}, {3, 10, 10, 10}, 0.0f, false));
 
     return test_cases;
 }
@@ -8798,6 +8802,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     // scatter
     test_cases.emplace_back(new test_scatter(GGML_TYPE_F32, GGML_TYPE_I32, {65536, 1, 1, 1}, {2048, 1, 1, 1}, 0.0f, true));
     test_cases.emplace_back(new test_scatter(GGML_TYPE_F32, GGML_TYPE_I32, {65536, 1, 1, 1}, {2048, 1, 1, 1}, 0.0f, false));
+    test_cases.emplace_back(new test_scatter(GGML_TYPE_F16, GGML_TYPE_I32, {65536, 1, 1, 1}, {2048, 1, 1, 1}, 0.0f, true));
+    test_cases.emplace_back(new test_scatter(GGML_TYPE_F16, GGML_TYPE_I32, {65536, 1, 1, 1}, {2048, 1, 1, 1}, 0.0f, false));
 
     return test_cases;
 }