diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 31040e278b..8f9f082c43 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -2229,8 +2229,42 @@ static void ggml_compute_forward_fill_f32(const ggml_compute_params * params, gg } } +static void ggml_compute_forward_fill_f16(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_fp16_t c = GGML_CPU_FP32_TO_FP16(ggml_get_op_params_f32(dst, 0)); + + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); + GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + + const auto [ir0, ir1] = get_thread_range(params, dst); + + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir/(ne2*ne1); + const int64_t i02 = (ir - i03*ne2*ne1)/ne1; + const int64_t i01 = (ir - i03*ne2*ne1 - i02*ne1); + + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1); + + ggml_vec_set_f16(ne0, dst_ptr, c); + } +} + void ggml_compute_forward_fill(const ggml_compute_params * params, ggml_tensor * dst) { - ggml_compute_forward_fill_f32(params, dst); + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_fill_f32(params, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_fill_f16(params, dst); + } break; + default: + { + GGML_ABORT("unsupported type for ggml_compute_forward_fill: %s", ggml_type_name(src0->type)); + } + } } // ggml_compute_tri diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 9744813f45..bc5bae4096 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -5177,7 +5177,7 @@ static struct ggml_tensor * ggml_fill_impl( struct ggml_tensor * a, float c, bool inplace) { - GGML_ASSERT(a->type == GGML_TYPE_F32); + GGML_ASSERT(a->type == GGML_TYPE_F32 || a->type == GGML_TYPE_F16); GGML_ASSERT(ggml_is_contiguous(a)); struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index f8318a14ef..9abb98cdbd 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -8370,6 +8370,9 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_fill(2.0f, GGML_TYPE_F32, { 303, 207, 11, 3 })); test_cases.emplace_back(new test_fill(-152.0f, GGML_TYPE_F32, { 800, 600, 4, 4 })); test_cases.emplace_back(new test_fill(3.5f, GGML_TYPE_F32, { 2048, 512, 2, 2 })); + test_cases.emplace_back(new test_fill(2.0f, GGML_TYPE_F16, { 303, 207, 11, 3 })); + test_cases.emplace_back(new test_fill(-152.0f, GGML_TYPE_F16, { 800, 600, 4, 4 })); + test_cases.emplace_back(new test_fill(3.5f, GGML_TYPE_F16, { 2048, 512, 2, 2 })); test_cases.emplace_back(new test_diag()); test_cases.emplace_back(new test_diag(GGML_TYPE_F32, { 79, 1, 19, 13 }));