Merge pull request #6 from gaugarg-nv/get_host_buffer_type

Support device-specific host buffer types in meta backend
This commit is contained in:
Johannes Gäßler 2026-02-16 15:11:08 +01:00 committed by GitHub
commit f0198ef6fc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 27 additions and 2 deletions

View File

@@ -141,6 +141,8 @@ static ggml_backend_t ggml_backend_meta_device_init_backend(ggml_backend_dev_t d
static ggml_backend_buffer_type_t ggml_backend_meta_device_get_buffer_type(ggml_backend_dev_t dev);
static ggml_backend_buffer_type_t ggml_backend_meta_device_get_host_buffer_type(ggml_backend_dev_t dev);
static bool ggml_backend_meta_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
GGML_ASSERT(ggml_backend_dev_is_meta(dev));
const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context;
@@ -175,7 +177,7 @@ static const ggml_backend_device_i ggml_backend_meta_device_iface = {
/* .get_props = */ ggml_backend_meta_device_get_props,
/* .init_backend = */ ggml_backend_meta_device_init_backend,
/* .get_buffer_type = */ ggml_backend_meta_device_get_buffer_type,
/* .get_host_buffer_type = */ nullptr,
/* .get_host_buffer_type = */ ggml_backend_meta_device_get_host_buffer_type,
/* .buffer_from_host_ptr = */ nullptr,
/* .supports_op = */ ggml_backend_meta_device_supports_op,
/* .supports_buft = */ ggml_backend_meta_device_supports_buft,
@@ -346,6 +348,27 @@ static ggml_backend_buffer_type_t ggml_backend_meta_device_get_buffer_type(ggml_
return &result.first->second;
}
static ggml_backend_buffer_type_t ggml_backend_meta_device_get_host_buffer_type(ggml_backend_dev_t dev) {
    GGML_ASSERT(ggml_backend_dev_is_meta(dev));
    const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context;

    // A meta device can only expose a single host buffer type when every
    // underlying simple device provides one and they all agree on the same type.
    // Returns nullptr otherwise (including for an empty device list).
    ggml_backend_buffer_type_t common_buft = nullptr;
    for (ggml_backend_dev_t simple_dev : meta_dev_ctx->simple_devs) {
        ggml_backend_buffer_type_t cur_buft = ggml_backend_dev_host_buffer_type(simple_dev);
        const bool missing    = cur_buft == nullptr;
        const bool mismatched = common_buft != nullptr && cur_buft != common_buft;
        if (missing || mismatched) {
            return nullptr;
        }
        common_buft = cur_buft;
    }
    return common_buft;
}
size_t ggml_backend_meta_buft_n_bufts(ggml_backend_buffer_type_t meta_buft) {
GGML_ASSERT(ggml_backend_buft_is_meta(meta_buft));
const ggml_backend_meta_buffer_type_context * meta_buft_ctx = (const ggml_backend_meta_buffer_type_context *) meta_buft->context;

View File

@@ -111,6 +111,8 @@ static ggml_backend_buffer_i ggml_backend_amx_buffer_interface = {
/* .memset_tensor = */ ggml_backend_amx_buffer_memset_tensor,
/* .set_tensor = */ ggml_backend_amx_buffer_set_tensor,
/* .get_tensor = */ nullptr,
/* .set_tensor_2d = */ nullptr,
/* .get_tensor_2d = */ nullptr,
/* .cpy_tensor = */ nullptr,
/* .clear = */ ggml_backend_amx_buffer_clear,
/* .reset = */ nullptr,

View File

@@ -130,7 +130,7 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
split_state.axis = get_split_axis();
if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) {
const int64_t ne_full = tensor->ne[split_state.axis];
const int64_t granularity = get_split_granularity();
const int64_t granularity = get_split_granularity(split_state.axis);
GGML_ASSERT(ne_full % granularity == 0);
const float * tensor_split = ud->model->tensor_split();
std::vector<float> tensor_split_scan;