ggml : added f16 version of GGML_OP_FILL
This commit is contained in:
parent
5677f082b0
commit
1c830a178b
|
|
@ -2229,8 +2229,42 @@ static void ggml_compute_forward_fill_f32(const ggml_compute_params * params, gg
|
|||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_fill_f16(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
const ggml_fp16_t c = GGML_CPU_FP32_TO_FP16(ggml_get_op_params_f32(dst, 0));
|
||||
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
|
||||
|
||||
const auto [ir0, ir1] = get_thread_range(params, dst);
|
||||
|
||||
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
||||
const int64_t i03 = ir/(ne2*ne1);
|
||||
const int64_t i02 = (ir - i03*ne2*ne1)/ne1;
|
||||
const int64_t i01 = (ir - i03*ne2*ne1 - i02*ne1);
|
||||
|
||||
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
|
||||
|
||||
ggml_vec_set_f16(ne0, dst_ptr, c);
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_fill(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
ggml_compute_forward_fill_f32(params, dst);
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_fill_f32(params, dst);
|
||||
} break;
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
ggml_compute_forward_fill_f16(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("unsupported type for ggml_compute_forward_fill: %s", ggml_type_name(src0->type));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_tri
|
||||
|
|
|
|||
|
|
@ -5177,7 +5177,7 @@ static struct ggml_tensor * ggml_fill_impl(
|
|||
struct ggml_tensor * a,
|
||||
float c,
|
||||
bool inplace) {
|
||||
GGML_ASSERT(a->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(a->type == GGML_TYPE_F32 || a->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(ggml_is_contiguous(a));
|
||||
|
||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
|
|
|
|||
|
|
@ -8370,6 +8370,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
test_cases.emplace_back(new test_fill(2.0f, GGML_TYPE_F32, { 303, 207, 11, 3 }));
|
||||
test_cases.emplace_back(new test_fill(-152.0f, GGML_TYPE_F32, { 800, 600, 4, 4 }));
|
||||
test_cases.emplace_back(new test_fill(3.5f, GGML_TYPE_F32, { 2048, 512, 2, 2 }));
|
||||
test_cases.emplace_back(new test_fill(2.0f, GGML_TYPE_F16, { 303, 207, 11, 3 }));
|
||||
test_cases.emplace_back(new test_fill(-152.0f, GGML_TYPE_F16, { 800, 600, 4, 4 }));
|
||||
test_cases.emplace_back(new test_fill(3.5f, GGML_TYPE_F16, { 2048, 512, 2, 2 }));
|
||||
|
||||
test_cases.emplace_back(new test_diag());
|
||||
test_cases.emplace_back(new test_diag(GGML_TYPE_F32, { 79, 1, 19, 13 }));
|
||||
|
|
|
|||
Loading…
Reference in New Issue