ggml : added f16 version of GGML_OP_FILL

This commit is contained in:
Stanisław Szymczyk 2026-03-25 11:35:13 +01:00
parent 5677f082b0
commit 1c830a178b
3 changed files with 39 additions and 2 deletions

View File

@ -2229,8 +2229,42 @@ static void ggml_compute_forward_fill_f32(const ggml_compute_params * params, gg
}
}
static void ggml_compute_forward_fill_f16(const ggml_compute_params * params, ggml_tensor * dst) {
const ggml_fp16_t c = GGML_CPU_FP32_TO_FP16(ggml_get_op_params_f32(dst, 0));
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
const auto [ir0, ir1] = get_thread_range(params, dst);
for (int64_t ir = ir0; ir < ir1; ++ir) {
const int64_t i03 = ir/(ne2*ne1);
const int64_t i02 = (ir - i03*ne2*ne1)/ne1;
const int64_t i01 = (ir - i03*ne2*ne1 - i02*ne1);
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
ggml_vec_set_f16(ne0, dst_ptr, c);
}
}
void ggml_compute_forward_fill(const ggml_compute_params * params, ggml_tensor * dst) {
ggml_compute_forward_fill_f32(params, dst);
const ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_fill_f32(params, dst);
} break;
case GGML_TYPE_F16:
{
ggml_compute_forward_fill_f16(params, dst);
} break;
default:
{
GGML_ABORT("unsupported type for ggml_compute_forward_fill: %s", ggml_type_name(src0->type));
}
}
}
// ggml_compute_tri

View File

@ -5177,7 +5177,7 @@ static struct ggml_tensor * ggml_fill_impl(
struct ggml_tensor * a,
float c,
bool inplace) {
GGML_ASSERT(a->type == GGML_TYPE_F32);
GGML_ASSERT(a->type == GGML_TYPE_F32 || a->type == GGML_TYPE_F16);
GGML_ASSERT(ggml_is_contiguous(a));
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);

View File

@ -8370,6 +8370,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_fill(2.0f, GGML_TYPE_F32, { 303, 207, 11, 3 }));
test_cases.emplace_back(new test_fill(-152.0f, GGML_TYPE_F32, { 800, 600, 4, 4 }));
test_cases.emplace_back(new test_fill(3.5f, GGML_TYPE_F32, { 2048, 512, 2, 2 }));
test_cases.emplace_back(new test_fill(2.0f, GGML_TYPE_F16, { 303, 207, 11, 3 }));
test_cases.emplace_back(new test_fill(-152.0f, GGML_TYPE_F16, { 800, 600, 4, 4 }));
test_cases.emplace_back(new test_fill(3.5f, GGML_TYPE_F16, { 2048, 512, 2, 2 }));
test_cases.emplace_back(new test_diag());
test_cases.emplace_back(new test_diag(GGML_TYPE_F32, { 79, 1, 19, 13 }));