diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp
index 5b939fdf62..f37c2e2388 100644
--- a/ggml/src/ggml-backend-meta.cpp
+++ b/ggml/src/ggml-backend-meta.cpp
@@ -141,6 +141,8 @@ static ggml_backend_t ggml_backend_meta_device_init_backend(ggml_backend_dev_t d
 
 static ggml_backend_buffer_type_t ggml_backend_meta_device_get_buffer_type(ggml_backend_dev_t dev);
 
+static ggml_backend_buffer_type_t ggml_backend_meta_device_get_host_buffer_type(ggml_backend_dev_t dev);
+
 static bool ggml_backend_meta_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
     GGML_ASSERT(ggml_backend_dev_is_meta(dev));
     const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context;
@@ -175,7 +177,7 @@ static const ggml_backend_device_i ggml_backend_meta_device_iface = {
     /* .get_props = */ ggml_backend_meta_device_get_props,
     /* .init_backend = */ ggml_backend_meta_device_init_backend,
     /* .get_buffer_type = */ ggml_backend_meta_device_get_buffer_type,
-    /* .get_host_buffer_type = */ nullptr,
+    /* .get_host_buffer_type = */ ggml_backend_meta_device_get_host_buffer_type,
     /* .buffer_from_host_ptr = */ nullptr,
     /* .supports_op = */ ggml_backend_meta_device_supports_op,
     /* .supports_buft = */ ggml_backend_meta_device_supports_buft,
@@ -346,6 +348,27 @@ static ggml_backend_buffer_type_t ggml_backend_meta_device_get_buffer_type(ggml_
     return &result.first->second;
 }
 
+static ggml_backend_buffer_type_t ggml_backend_meta_device_get_host_buffer_type(ggml_backend_dev_t dev) {
+    GGML_ASSERT(ggml_backend_dev_is_meta(dev));
+    const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context;
+
+    ggml_backend_buffer_type_t host_buft = nullptr;
+    for (ggml_backend_dev_t simple_dev : meta_dev_ctx->simple_devs) {
+        ggml_backend_buffer_type_t simple_host_buft = ggml_backend_dev_host_buffer_type(simple_dev);
+        if (simple_host_buft == nullptr) {
+            return nullptr;
+        }
+        if (host_buft == nullptr) {
+            host_buft = simple_host_buft;
+        } else if (host_buft != simple_host_buft) {
+            // if different simple devices have different host buffer types,
+            // we cannot provide a single host buffer type for the meta device
+            return nullptr;
+        }
+    }
+    return host_buft;
+}
+
 size_t ggml_backend_meta_buft_n_bufts(ggml_backend_buffer_type_t meta_buft) {
     GGML_ASSERT(ggml_backend_buft_is_meta(meta_buft));
     const ggml_backend_meta_buffer_type_context * meta_buft_ctx = (const ggml_backend_meta_buffer_type_context *) meta_buft->context;
diff --git a/ggml/src/ggml-cpu/amx/amx.cpp b/ggml/src/ggml-cpu/amx/amx.cpp
index 895a571375..791b051cb2 100644
--- a/ggml/src/ggml-cpu/amx/amx.cpp
+++ b/ggml/src/ggml-cpu/amx/amx.cpp
@@ -111,6 +111,8 @@ static ggml_backend_buffer_i ggml_backend_amx_buffer_interface = {
     /* .memset_tensor = */ ggml_backend_amx_buffer_memset_tensor,
     /* .set_tensor = */ ggml_backend_amx_buffer_set_tensor,
     /* .get_tensor = */ nullptr,
+    /* .set_tensor_2d = */ nullptr,
+    /* .get_tensor_2d = */ nullptr,
     /* .cpy_tensor = */ nullptr,
     /* .clear = */ ggml_backend_amx_buffer_clear,
     /* .reset = */ nullptr,
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index b0713cab45..7c5caf9c8c 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -130,7 +130,7 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
     split_state.axis = get_split_axis();
     if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) {
         const int64_t ne_full = tensor->ne[split_state.axis];
-        const int64_t granularity = get_split_granularity();
+        const int64_t granularity = get_split_granularity(split_state.axis);
         GGML_ASSERT(ne_full % granularity == 0);
         const float * tensor_split = ud->model->tensor_split();
         std::vector tensor_split_scan;