Support device-specific host buffer types when all underlying backends expose the same type. For CUDA, this allows using pinned (page-locked) host memory instead of pageable memory.

Fix compilation errors.
Gaurav Garg 2026-02-16 15:39:26 +05:30
parent fd24533e89
commit aa8b62105c
3 changed files with 27 additions and 2 deletions
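In practice, a caller can now ask the meta device for a host buffer type and get the backend-specific one (on CUDA, pinned memory) whenever the underlying simple devices agree on it. Below is a minimal sketch of such a caller, not taken from this commit; the helper name alloc_host_staging and the meta_dev/size parameters are made up for illustration, and the fallback to the plain CPU buffer type is one possible policy, not something this change mandates.

#include "ggml-backend.h"

// Sketch only: allocate a host staging buffer for a meta device, preferring the
// device-specific host buffer type (pinned memory with the CUDA backend) and
// falling back to pageable CPU memory when the underlying backends do not
// expose a common host buffer type.
static ggml_backend_buffer_t alloc_host_staging(ggml_backend_dev_t meta_dev, size_t size) {
    ggml_backend_buffer_type_t host_buft = ggml_backend_dev_host_buffer_type(meta_dev);
    if (host_buft == nullptr) {
        host_buft = ggml_backend_cpu_buffer_type();
    }
    return ggml_backend_buft_alloc_buffer(host_buft, size);
}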

View File

@@ -141,6 +141,8 @@ static ggml_backend_t ggml_backend_meta_device_init_backend(ggml_backend_dev_t d
static ggml_backend_buffer_type_t ggml_backend_meta_device_get_buffer_type(ggml_backend_dev_t dev);
static ggml_backend_buffer_type_t ggml_backend_meta_device_get_host_buffer_type(ggml_backend_dev_t dev);
static bool ggml_backend_meta_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
    GGML_ASSERT(ggml_backend_dev_is_meta(dev));
    const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context;
@@ -175,7 +177,7 @@ static const ggml_backend_device_i ggml_backend_meta_device_iface = {
    /* .get_props            = */ ggml_backend_meta_device_get_props,
    /* .init_backend         = */ ggml_backend_meta_device_init_backend,
    /* .get_buffer_type      = */ ggml_backend_meta_device_get_buffer_type,
    /* .get_host_buffer_type = */ nullptr,
    /* .get_host_buffer_type = */ ggml_backend_meta_device_get_host_buffer_type,
    /* .buffer_from_host_ptr = */ nullptr,
    /* .supports_op          = */ ggml_backend_meta_device_supports_op,
    /* .supports_buft        = */ ggml_backend_meta_device_supports_buft,
@@ -346,6 +348,27 @@ static ggml_backend_buffer_type_t ggml_backend_meta_device_get_buffer_type(ggml_
    return &result.first->second;
}
static ggml_backend_buffer_type_t ggml_backend_meta_device_get_host_buffer_type(ggml_backend_dev_t dev) {
    GGML_ASSERT(ggml_backend_dev_is_meta(dev));
    const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context;
    ggml_backend_buffer_type_t host_buft = nullptr;
    for (ggml_backend_dev_t simple_dev : meta_dev_ctx->simple_devs) {
        ggml_backend_buffer_type_t simple_host_buft = ggml_backend_dev_host_buffer_type(simple_dev);
        if (simple_host_buft == nullptr) {
            return nullptr;
        }
        if (host_buft == nullptr) {
            host_buft = simple_host_buft;
        } else if (host_buft != simple_host_buft) {
            // if different simple devices have different host buffer types,
            // we cannot provide a single host buffer type for the meta device
            return nullptr;
        }
    }
    return host_buft;
}
size_t ggml_backend_meta_buft_n_bufts(ggml_backend_buffer_type_t meta_buft) {
    GGML_ASSERT(ggml_backend_buft_is_meta(meta_buft));
    const ggml_backend_meta_buffer_type_context * meta_buft_ctx = (const ggml_backend_meta_buffer_type_context *) meta_buft->context;
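The pointer comparison in the new function works because host buffer types are typically process-wide singletons; the CUDA backend, for example, reports one shared host buffer type object for all of its devices, so a meta device built from several GPUs of the same backend passes the check, while a mix of backends (or a backend with no host buffer type) yields nullptr. The following diagnostic sketch is not part of this commit; it only lists what each registered device reports, using the standard ggml device registry API.

#include <cstdio>
#include "ggml-backend.h"

// Sketch only: print the host buffer type reported by every registered device,
// to see which devices could share one through a meta device.
int main() {
    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
        ggml_backend_buffer_type_t host_buft = ggml_backend_dev_host_buffer_type(dev);
        printf("%s: host buffer type = %s\n",
               ggml_backend_dev_name(dev),
               host_buft != nullptr ? ggml_backend_buft_name(host_buft) : "(none)");
    }
    return 0;
}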

View File

@@ -111,6 +111,8 @@ static ggml_backend_buffer_i ggml_backend_amx_buffer_interface = {
    /* .memset_tensor = */ ggml_backend_amx_buffer_memset_tensor,
    /* .set_tensor    = */ ggml_backend_amx_buffer_set_tensor,
    /* .get_tensor    = */ nullptr,
    /* .set_tensor_2d = */ nullptr,
    /* .get_tensor_2d = */ nullptr,
    /* .cpy_tensor    = */ nullptr,
    /* .clear         = */ ggml_backend_amx_buffer_clear,
    /* .reset         = */ nullptr,

View File

@@ -130,7 +130,7 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
    split_state.axis = get_split_axis();
    if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) {
        const int64_t ne_full = tensor->ne[split_state.axis];
        const int64_t granularity = get_split_granularity();
        const int64_t granularity = get_split_granularity(split_state.axis);
        GGML_ASSERT(ne_full % granularity == 0);
        const float * tensor_split = ud->model->tensor_split();
        std::vector<float> tensor_split_scan;