ggml : added f16 version of GGML_OP_SCATTER

This commit is contained in:
Stanisław Szymczyk 2026-03-25 11:08:29 +01:00
parent 54945c7ec1
commit 5677f082b0
4 changed files with 99 additions and 13 deletions

View File

@ -11318,6 +11318,68 @@ static void ggml_compute_forward_scatter_f32(
}
}
// Scatter a constant into selected elements of each row (F16 variant).
//
//   dst  = src0 (copied unless the op runs in place), then
//   dst[i1,i2,i3][id] = c  for every index id in src1[i1,i2,i3][0..ne10)
//
// op params: [0] = fill value c (stored as f32, converted to f16 here),
//            [1] = inplace flag (dst aliases src0, so the row copy is skipped)
static void ggml_compute_forward_scatter_f16(
        const ggml_compute_params * params,
        ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0]; // source values (F16)
    const ggml_tensor * src1 = dst->src[1]; // scatter indices (I32)
    const ggml_fp16_t c = GGML_CPU_FP32_TO_FP16(ggml_get_op_params_f32(dst, 0));
    const bool inplace = ggml_get_op_params_i32(dst, 1) != 0;
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
    GGML_ASSERT(dst->type == GGML_TYPE_F16);
    GGML_ASSERT(src0->type == GGML_TYPE_F16);
    GGML_ASSERT(src1->type == GGML_TYPE_I32);
    GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t));
    const int ith = params->ith;
    const int nth = params->nth;
    const int nr = ggml_nrows(src0);
    GGML_TENSOR_BINARY_OP_LOCALS
    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
    // rows per thread
    const int dr = (nr + nth - 1)/nth;
    // row range for this thread
    const int ir0 = dr*ith;
    const int ir1 = MIN(ir0 + dr, nr);
    for (int ir = ir0; ir < ir1; ++ir) {
        // unflatten the row index into (i1, i2, i3); same_shape guarantees
        // equal extents for src0/dst, so one index set addresses all tensors
        const int i3 = ir/(ne2*ne1);
        const int i2 = (ir - i3*ne2*ne1)/ne1;
        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
        // NOTE: src0 is addressed with its own strides (nb01/nb02/nb03) —
        // same shape does not imply same layout, so dst strides must not be
        // reused for src0 reads
        const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
        const int32_t * ids_ptr = (int32_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
        ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
        // copy whole row from src0 (skipped when dst aliases src0)
        if (!inplace) {
            // ggml_vec_cpy_f16(ne00, dst_ptr, src0_ptr)
            for (int i = 0; i < ne00; ++i) {
                dst_ptr[i] = src0_ptr[i];
            }
        }
        // set dst elements indicated by indices in src1 to c
        for (int j = 0; j < ne10; ++j) {
            const int id = ids_ptr[j];
            GGML_ASSERT(id >= 0 && id < ne00);
            dst_ptr[id] = c;
        }
    }
}
void ggml_compute_forward_scatter(
const ggml_compute_params * params,
ggml_tensor * dst) {
@ -11329,6 +11391,10 @@ void ggml_compute_forward_scatter(
{
ggml_compute_forward_scatter_f32(params, dst);
} break;
case GGML_TYPE_F16:
{
ggml_compute_forward_scatter_f16(params, dst);
} break;
default:
{
GGML_ABORT("unsupported type for ggml_compute_forward_scatter: %s", ggml_type_name(src0->type));

View File

@ -1,7 +1,9 @@
#include "scatter.cuh"
#include "convert.cuh"
template <typename T>
static __global__ void scatter_kernel(
const int32_t * src0, float * dst, const float c,
const int32_t * src0, T * dst, const T c,
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
size_t nb1, size_t nb2, size_t nb3,
size_t nb01, size_t nb02, size_t nb03
@ -15,7 +17,7 @@ static __global__ void scatter_kernel(
const int64_t i2 = (block_idx / ne01) % ne02;
const int64_t i3 = block_idx / (ne01 * ne02);
float * dst_row = (float *)((char *)dst + i1*nb1 + i2*nb2 + i3*nb3);
T * dst_row = (T *)((char *)dst + i1*nb1 + i2*nb2 + i3*nb3);
const int * src0_row = (const int *)((const char *)src0 + i1*nb01 + i2*nb02 + i3*nb03);
for (int64_t i0 = threadIdx.x; i0 < ne00; i0 += blockDim.x) {
@ -35,11 +37,9 @@ void ggml_cuda_op_scatter(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == src0->type);
GGML_ASSERT(src1->type == GGML_TYPE_I32);
GGML_ASSERT(nb00 == sizeof(float));
GGML_ASSERT(nb10 == sizeof(int32_t));
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(dst));
@ -58,17 +58,31 @@ void ggml_cuda_op_scatter(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
// step 2 - set elements in dst indicated by ids to c
const int32_t * src1_d = (const int32_t *) src1->data;
float * dst_d = (float *) dst->data;
void * dst_d = dst->data;
int threads = std::min((int) ne10, 512); // ids
int64_t total_blocks = ne11 * ne12 * ne13;
int blocks = (int) std::min((int64_t) 65535, total_blocks);
scatter_kernel<<<blocks, threads, 0, ctx.stream()>>>(
src1_d, dst_d, c,
ne10, ne11, ne12, ne13,
nb1, nb2, nb3,
nb11, nb12, nb13
);
switch (dst->type) {
case GGML_TYPE_F32:
scatter_kernel<<<blocks, threads, 0, ctx.stream()>>>(
src1_d, (float *) dst_d, c,
ne10, ne11, ne12, ne13,
nb1, nb2, nb3,
nb11, nb12, nb13
);
break;
case GGML_TYPE_F16:
scatter_kernel<<<blocks, threads, 0, ctx.stream()>>>(
src1_d, (half *) dst_d, ggml_cuda_cast<half>(c),
ne10, ne11, ne12, ne13,
nb1, nb2, nb3,
nb11, nb12, nb13
);
break;
default:
GGML_ABORT("unsupported type");
}
}

View File

@ -6212,7 +6212,7 @@ static struct ggml_tensor * ggml_scatter_impl(
float c,
bool inplace) {
GGML_ASSERT(a->type == GGML_TYPE_F32);
GGML_ASSERT(a->type == GGML_TYPE_F32 || a->type == GGML_TYPE_F16);
GGML_ASSERT(ids->type == GGML_TYPE_I32);
GGML_ASSERT(a->ne[1] == ids->ne[1]);

View File

@ -8538,6 +8538,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_scatter(GGML_TYPE_F32, GGML_TYPE_I32, {10, 1, 1, 1}, {3, 1, 1, 1}, 0.0f, false));
test_cases.emplace_back(new test_scatter(GGML_TYPE_F32, GGML_TYPE_I32, {10, 10, 10, 10}, {3, 10, 10, 10}, 0.0f, true));
test_cases.emplace_back(new test_scatter(GGML_TYPE_F32, GGML_TYPE_I32, {10, 10, 10, 10}, {3, 10, 10, 10}, 0.0f, false));
test_cases.emplace_back(new test_scatter(GGML_TYPE_F16, GGML_TYPE_I32, {10, 1, 1, 1}, {3, 1, 1, 1}, 0.0f, true));
test_cases.emplace_back(new test_scatter(GGML_TYPE_F16, GGML_TYPE_I32, {10, 1, 1, 1}, {3, 1, 1, 1}, 0.0f, false));
test_cases.emplace_back(new test_scatter(GGML_TYPE_F16, GGML_TYPE_I32, {10, 10, 10, 10}, {3, 10, 10, 10}, 0.0f, true));
test_cases.emplace_back(new test_scatter(GGML_TYPE_F16, GGML_TYPE_I32, {10, 10, 10, 10}, {3, 10, 10, 10}, 0.0f, false));
return test_cases;
}
@ -8798,6 +8802,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
// scatter
test_cases.emplace_back(new test_scatter(GGML_TYPE_F32, GGML_TYPE_I32, {65536, 1, 1, 1}, {2048, 1, 1, 1}, 0.0f, true));
test_cases.emplace_back(new test_scatter(GGML_TYPE_F32, GGML_TYPE_I32, {65536, 1, 1, 1}, {2048, 1, 1, 1}, 0.0f, false));
test_cases.emplace_back(new test_scatter(GGML_TYPE_F16, GGML_TYPE_I32, {65536, 1, 1, 1}, {2048, 1, 1, 1}, 0.0f, true));
test_cases.emplace_back(new test_scatter(GGML_TYPE_F16, GGML_TYPE_I32, {65536, 1, 1, 1}, {2048, 1, 1, 1}, 0.0f, false));
return test_cases;
}