Merge pull request #6 from gaugarg-nv/get_host_buffer_type
Support device-specific host buffer types in meta backend
This commit is contained in:
commit
f0198ef6fc
|
|
@ -141,6 +141,8 @@ static ggml_backend_t ggml_backend_meta_device_init_backend(ggml_backend_dev_t d
|
|||
|
||||
static ggml_backend_buffer_type_t ggml_backend_meta_device_get_buffer_type(ggml_backend_dev_t dev);
|
||||
|
||||
static ggml_backend_buffer_type_t ggml_backend_meta_device_get_host_buffer_type(ggml_backend_dev_t dev);
|
||||
|
||||
static bool ggml_backend_meta_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
|
||||
GGML_ASSERT(ggml_backend_dev_is_meta(dev));
|
||||
const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context;
|
||||
|
|
@ -175,7 +177,7 @@ static const ggml_backend_device_i ggml_backend_meta_device_iface = {
|
|||
/* .get_props = */ ggml_backend_meta_device_get_props,
|
||||
/* .init_backend = */ ggml_backend_meta_device_init_backend,
|
||||
/* .get_buffer_type = */ ggml_backend_meta_device_get_buffer_type,
|
||||
/* .get_host_buffer_type = */ nullptr,
|
||||
/* .get_host_buffer_type = */ ggml_backend_meta_device_get_host_buffer_type,
|
||||
/* .buffer_from_host_ptr = */ nullptr,
|
||||
/* .supports_op = */ ggml_backend_meta_device_supports_op,
|
||||
/* .supports_buft = */ ggml_backend_meta_device_supports_buft,
|
||||
|
|
@ -346,6 +348,27 @@ static ggml_backend_buffer_type_t ggml_backend_meta_device_get_buffer_type(ggml_
|
|||
return &result.first->second;
|
||||
}
|
||||
|
||||
// Returns the host buffer type shared by all simple devices of a meta device.
//
// The result is non-null only when every simple device reports a host buffer
// type and all of them report the same one. It is nullptr when:
//   - any simple device has no host buffer type,
//   - two simple devices report different host buffer types, or
//   - the meta device has no simple devices at all.
static ggml_backend_buffer_type_t ggml_backend_meta_device_get_host_buffer_type(ggml_backend_dev_t dev) {
    GGML_ASSERT(ggml_backend_dev_is_meta(dev));
    const ggml_backend_meta_device_context * ctx = (const ggml_backend_meta_device_context *) dev->context;

    ggml_backend_buffer_type_t common_buft = nullptr;
    for (ggml_backend_dev_t member_dev : ctx->simple_devs) {
        const ggml_backend_buffer_type_t candidate = ggml_backend_dev_host_buffer_type(member_dev);

        // A member without a host buffer type rules out a host buffer type for the meta device.
        if (candidate == nullptr) {
            return nullptr;
        }

        // A single host buffer type cannot represent members that disagree with each other.
        if (common_buft != nullptr && common_buft != candidate) {
            return nullptr;
        }

        // Either the first member seen, or one that matches what was already recorded.
        common_buft = candidate;
    }

    return common_buft;
}
|
||||
|
||||
size_t ggml_backend_meta_buft_n_bufts(ggml_backend_buffer_type_t meta_buft) {
|
||||
GGML_ASSERT(ggml_backend_buft_is_meta(meta_buft));
|
||||
const ggml_backend_meta_buffer_type_context * meta_buft_ctx = (const ggml_backend_meta_buffer_type_context *) meta_buft->context;
|
||||
|
|
|
|||
|
|
@ -111,6 +111,8 @@ static ggml_backend_buffer_i ggml_backend_amx_buffer_interface = {
|
|||
/* .memset_tensor = */ ggml_backend_amx_buffer_memset_tensor,
|
||||
/* .set_tensor = */ ggml_backend_amx_buffer_set_tensor,
|
||||
/* .get_tensor = */ nullptr,
|
||||
/* .set_tensor_2d = */ nullptr,
|
||||
/* .get_tensor_2d = */ nullptr,
|
||||
/* .cpy_tensor = */ nullptr,
|
||||
/* .clear = */ ggml_backend_amx_buffer_clear,
|
||||
/* .reset = */ nullptr,
|
||||
|
|
|
|||
|
|
@ -130,7 +130,7 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
|
|||
split_state.axis = get_split_axis();
|
||||
if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) {
|
||||
const int64_t ne_full = tensor->ne[split_state.axis];
|
||||
const int64_t granularity = get_split_granularity();
|
||||
const int64_t granularity = get_split_granularity(split_state.axis);
|
||||
GGML_ASSERT(ne_full % granularity == 0);
|
||||
const float * tensor_split = ud->model->tensor_split();
|
||||
std::vector<float> tensor_split_scan;
|
||||
|
|
|
|||
Loading…
Reference in New Issue