From 7758a58aa83a82be7d2bdb7035f63f37ff2cf1d5 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 18 Jan 2026 15:37:25 +0200 Subject: [PATCH] metal : support virtual devices --- ggml/src/ggml-metal/ggml-metal-context.h | 2 + ggml/src/ggml-metal/ggml-metal-context.m | 10 +- ggml/src/ggml-metal/ggml-metal-device.cpp | 8 +- ggml/src/ggml-metal/ggml-metal-device.h | 7 +- ggml/src/ggml-metal/ggml-metal-device.m | 18 +- ggml/src/ggml-metal/ggml-metal.cpp | 280 +++++++++++++++------- 6 files changed, 224 insertions(+), 101 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-context.h b/ggml/src/ggml-metal/ggml-metal-context.h index ec2b686b73..bb6bf7da9b 100644 --- a/ggml/src/ggml-metal/ggml-metal-context.h +++ b/ggml/src/ggml-metal/ggml-metal-context.h @@ -15,6 +15,8 @@ typedef struct ggml_metal * ggml_metal_t; ggml_metal_t ggml_metal_init(ggml_metal_device_t dev); void ggml_metal_free(ggml_metal_t ctx); +const char * ggml_metal_get_name(ggml_metal_t ctx); + void ggml_metal_synchronize(ggml_metal_t ctx); void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); diff --git a/ggml/src/ggml-metal/ggml-metal-context.m b/ggml/src/ggml-metal/ggml-metal-context.m index 42a35736ee..13bc3b650d 100644 --- a/ggml/src/ggml-metal/ggml-metal-context.m +++ b/ggml/src/ggml-metal/ggml-metal-context.m @@ -24,6 +24,8 @@ struct ggml_metal_command_buffer { }; struct ggml_metal { + char name[128]; + ggml_metal_device_t dev; ggml_metal_library_t lib; @@ -117,7 +119,9 @@ ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) { } } - //const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev); + const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev); + + snprintf(res->name, sizeof(res->name), "%s", props_dev->name); res->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT); @@ -209,6 +213,10 @@ void ggml_metal_free(ggml_metal_t ctx) { free(ctx); } +const char * ggml_metal_get_name(ggml_metal_t ctx) { + return ctx->name; +} + void ggml_metal_synchronize(ggml_metal_t ctx) { // wait for any backend operations to finish if (ctx->cmd_buf_last) { diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 04c6137c5a..377b0d3eb8 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -17,10 +17,12 @@ struct ggml_metal_device_deleter { typedef std::unique_ptr ggml_metal_device_ptr; -ggml_metal_device_t ggml_metal_device_get(void) { - static ggml_metal_device_ptr ctx { ggml_metal_device_init() }; +ggml_metal_device_t ggml_metal_device_get(int device) { + static std::vector devs; - return ctx.get(); + devs.emplace_back(ggml_metal_device_init(device)); + + return devs.back().get(); } struct ggml_metal_pipelines { diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 3d01c56fb8..ca499d8195 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -205,7 +205,9 @@ void ggml_metal_rsets_free(ggml_metal_rsets_t rsets); // struct ggml_metal_device_props { + int device; char name[128]; + char desc[128]; size_t max_buffer_size; size_t max_working_set_size; @@ -224,11 +226,10 @@ struct ggml_metal_device_props { int op_offload_min_batch_size; }; -ggml_metal_device_t ggml_metal_device_init(void); +ggml_metal_device_t ggml_metal_device_init(int device); void ggml_metal_device_free(ggml_metal_device_t dev); -// return a singleton that is automatically destroyed when the program exits -ggml_metal_device_t ggml_metal_device_get(void); +ggml_metal_device_t ggml_metal_device_get(int device); void * ggml_metal_device_get_obj (ggml_metal_device_t dev); // id void * ggml_metal_device_get_queue(ggml_metal_device_t dev); // id diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 7f9c384c34..8f74fe4427 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -24,9 +24,6 @@ static const NSInteger MTLGPUFamilyMetal3_GGML = 5001; static const NSInteger MTLGPUFamilyMetal4_GGML = 5002; -// virtual address for GPU memory allocations -static atomic_uintptr_t g_addr_device = 0x000000400ULL; - #if !GGML_METAL_EMBED_LIBRARY // Here to assist with NSBundle Path Hack @interface GGMLMetalClass : NSObject @@ -523,6 +520,9 @@ struct ggml_metal_device { ggml_metal_library_t library; struct ggml_metal_device_props props; + + // virtual address for GPU memory allocations + atomic_uintptr_t addr_virt; }; // @@ -618,7 +618,7 @@ void ggml_metal_rsets_free(ggml_metal_rsets_t rsets) { free(rsets); } -ggml_metal_device_t ggml_metal_device_init(void) { +ggml_metal_device_t ggml_metal_device_init(int device) { ggml_metal_device_t dev = calloc(1, sizeof(struct ggml_metal_device)); assert(dev != NULL); @@ -632,6 +632,9 @@ ggml_metal_device_t ggml_metal_device_init(void) { GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__); } + dev->addr_virt = 0x000000400ULL; + + dev->props.device = device; dev->props.has_simdgroup_reduction = [dev->mtl_device supportsFamily:MTLGPUFamilyApple7]; dev->props.has_simdgroup_reduction |= [dev->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML]; @@ -792,7 +795,8 @@ ggml_metal_device_t ggml_metal_device_init(void) { dev->props.max_working_set_size = dev->mtl_device.maxBufferLength; } - strncpy(dev->props.name, [[dev->mtl_device name] UTF8String], sizeof(dev->props.name) - 1); + snprintf(dev->props.name, sizeof(dev->props.name), "%s%d", "MTL", device); + snprintf(dev->props.desc, sizeof(dev->props.desc), "%s", [[dev->mtl_device name] UTF8String]); dev->library = ggml_metal_library_init(dev); if (!dev->library) { @@ -1344,8 +1348,8 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size, res->all_data = ggml_metal_host_malloc(size_aligned); res->is_shared = true; } else { - // use virtual address from g_addr_device counter - res->all_data = (void *) atomic_fetch_add_explicit(&g_addr_device, size_aligned, memory_order_relaxed); + // use virtual address + res->all_data = (void *) atomic_fetch_add_explicit(&dev->addr_virt, size_aligned, memory_order_relaxed); res->is_shared = false; } res->all_size = size_aligned; diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp index 56b59f0afd..20d4f142f5 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -7,11 +7,12 @@ #include "ggml-metal-context.h" #include "ggml-metal-ops.h" -// globals +#define GGML_METAL_NAME "MTL" +#define GGML_METAL_MAX_DEVICES 16 -// initialized in ggml_backend_metal_reg -static ggml_backend_reg g_ggml_metal_reg; -static ggml_backend_device g_ggml_metal_device; +// number of Metal devices +// note: can be overriden with GGML_METAL_DEVICES env to simulate virtual devices +static int g_devices = 1; //////////////////////////////////////////////////////////////////////////////// // backend interface @@ -169,6 +170,11 @@ static ggml_backend_buffer_i ggml_backend_metal_buffer_private_i = { // buffer types // +struct ggml_backend_metal_buffer_type { + int device; + std::string name; +}; + // common method for allocating shread or private Metal buffers static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size, bool shared) { ggml_metal_device_t ctx_dev = (ggml_metal_device_t)buft->device->context; @@ -218,9 +224,9 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_ // default (shared) buffer type static const char * ggml_backend_metal_buffer_type_shared_get_name(ggml_backend_buffer_type_t buft) { - return "Metal"; + ggml_backend_metal_buffer_type * ctx = (ggml_backend_metal_buffer_type *)buft->context; - GGML_UNUSED(buft); + return ctx->name.c_str(); } static ggml_backend_buffer_t ggml_backend_metal_buffer_type_shared_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { @@ -249,29 +255,41 @@ static bool ggml_backend_metal_buffer_type_shared_is_host(ggml_backend_buffer_ty GGML_UNUSED(buft); } -static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_shared(void) { - static ggml_backend_buffer_type ggml_backend_buffer_type_metal = { - /* .iface = */ { - /* .get_name = */ ggml_backend_metal_buffer_type_shared_get_name, - /* .alloc_buffer = */ ggml_backend_metal_buffer_type_shared_alloc_buffer, - /* .get_alignment = */ ggml_backend_metal_buffer_type_shared_get_alignment, - /* .get_max_size = */ ggml_backend_metal_buffer_type_shared_get_max_size, - /* .get_alloc_size = */ ggml_backend_metal_buffer_type_shared_get_alloc_size, - /* .is_host = */ ggml_backend_metal_buffer_type_shared_is_host, - }, - /* .device = */ &g_ggml_metal_device, - /* .context = */ NULL, +static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_shared(int device) { + static std::mutex mutex; + std::lock_guard lock(mutex); + + static ggml_backend_buffer_type bufts[GGML_METAL_MAX_DEVICES]; + + static bool initialized = false; + if (!initialized) { + for (int i = 0; i < g_devices; ++i) { + bufts[i] = { + /* .iface = */ { + /* .get_name = */ ggml_backend_metal_buffer_type_shared_get_name, + /* .alloc_buffer = */ ggml_backend_metal_buffer_type_shared_alloc_buffer, + /* .get_alignment = */ ggml_backend_metal_buffer_type_shared_get_alignment, + /* .get_max_size = */ ggml_backend_metal_buffer_type_shared_get_max_size, + /* .get_alloc_size = */ ggml_backend_metal_buffer_type_shared_get_alloc_size, + /* .is_host = */ ggml_backend_metal_buffer_type_shared_is_host, + }, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_metal_reg(), i), + /* .context = */ new ggml_backend_metal_buffer_type{i, GGML_METAL_NAME + std::to_string(i)}, + }; + } + + initialized = true; }; - return &ggml_backend_buffer_type_metal; + return &bufts[device]; } // default (private) buffer type static const char * ggml_backend_metal_buffer_type_private_get_name(ggml_backend_buffer_type_t buft) { - return "Metal_Private"; + ggml_backend_metal_buffer_type * ctx = (ggml_backend_metal_buffer_type *)buft->context; - GGML_UNUSED(buft); + return ctx->name.c_str(); } static ggml_backend_buffer_t ggml_backend_metal_buffer_type_private_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { @@ -300,29 +318,41 @@ static bool ggml_backend_metal_buffer_type_private_is_host(ggml_backend_buffer_t GGML_UNUSED(buft); } -static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_private(void) { - static ggml_backend_buffer_type ggml_backend_buffer_type_metal = { - /* .iface = */ { - /* .get_name = */ ggml_backend_metal_buffer_type_private_get_name, - /* .alloc_buffer = */ ggml_backend_metal_buffer_type_private_alloc_buffer, - /* .get_alignment = */ ggml_backend_metal_buffer_type_private_get_alignment, - /* .get_max_size = */ ggml_backend_metal_buffer_type_private_get_max_size, - /* .get_alloc_size = */ ggml_backend_metal_buffer_type_private_get_alloc_size, - /* .is_host = */ ggml_backend_metal_buffer_type_private_is_host, - }, - /* .device = */ &g_ggml_metal_device, - /* .context = */ NULL, +static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_private(int device) { + static std::mutex mutex; + std::lock_guard lock(mutex); + + static ggml_backend_buffer_type bufts[GGML_METAL_MAX_DEVICES]; + + static bool initialized = false; + if (!initialized) { + for (int i = 0; i < g_devices; ++i) { + bufts[i] = { + /* .iface = */ { + /* .get_name = */ ggml_backend_metal_buffer_type_private_get_name, + /* .alloc_buffer = */ ggml_backend_metal_buffer_type_private_alloc_buffer, + /* .get_alignment = */ ggml_backend_metal_buffer_type_private_get_alignment, + /* .get_max_size = */ ggml_backend_metal_buffer_type_private_get_max_size, + /* .get_alloc_size = */ ggml_backend_metal_buffer_type_private_get_alloc_size, + /* .is_host = */ ggml_backend_metal_buffer_type_private_is_host, + }, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_metal_reg(), i), + /* .context = */ new ggml_backend_metal_buffer_type{i, GGML_METAL_NAME + std::to_string(i) + "_Private"}, + }; + } + + initialized = true; }; - return &ggml_backend_buffer_type_metal; + return &bufts[device]; } // mapped buffer type static const char * ggml_backend_metal_buffer_type_mapped_get_name(ggml_backend_buffer_type_t buft) { - return "Metal_Mapped"; + ggml_backend_metal_buffer_type * ctx = (ggml_backend_metal_buffer_type *)buft->context; - GGML_UNUSED(buft); + return ctx->name.c_str(); } static ggml_backend_buffer_t ggml_backend_metal_buffer_type_mapped_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { @@ -352,31 +382,43 @@ static bool ggml_backend_metal_buffer_type_mapped_is_host(ggml_backend_buffer_ty GGML_UNUSED(buft); } -static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_mapped(void) { - // note: not obvious, but this buffer type still needs to implement .alloc_buffer: - // https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2333177099 - static ggml_backend_buffer_type ggml_backend_buffer_type_mapped_metal = { - /* .iface = */ { - /* .get_name = */ ggml_backend_metal_buffer_type_mapped_get_name, - /* .alloc_buffer = */ ggml_backend_metal_buffer_type_mapped_alloc_buffer, - /* .get_alignment = */ ggml_backend_metal_buffer_type_mapped_get_alignment, - /* .get_max_size = */ ggml_backend_metal_buffer_type_mapped_get_max_size, - /* .get_alloc_size = */ ggml_backend_metal_buffer_type_mapped_get_alloc_size, - /* .is_host = */ ggml_backend_metal_buffer_type_mapped_is_host, - }, - /* .device = */ &g_ggml_metal_device, - /* .context = */ NULL, +static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_mapped(int device) { + static std::mutex mutex; + std::lock_guard lock(mutex); + + static ggml_backend_buffer_type bufts[GGML_METAL_MAX_DEVICES]; + + static bool initialized = false; + if (!initialized) { + for (int i = 0; i < g_devices; ++i) { + // note: not obvious, but this buffer type still needs to implement .alloc_buffer: + // https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2333177099 + bufts[i] = { + /* .iface = */ { + /* .get_name = */ ggml_backend_metal_buffer_type_mapped_get_name, + /* .alloc_buffer = */ ggml_backend_metal_buffer_type_mapped_alloc_buffer, + /* .get_alignment = */ ggml_backend_metal_buffer_type_mapped_get_alignment, + /* .get_max_size = */ ggml_backend_metal_buffer_type_mapped_get_max_size, + /* .get_alloc_size = */ ggml_backend_metal_buffer_type_mapped_get_alloc_size, + /* .is_host = */ ggml_backend_metal_buffer_type_mapped_is_host, + }, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_metal_reg(), i), + /* .context = */ new ggml_backend_metal_buffer_type{i, GGML_METAL_NAME + std::to_string(i) + "_Mapped"}, + }; + } + + initialized = true; }; - return &ggml_backend_buffer_type_mapped_metal; + return &bufts[device]; } // backend static const char * ggml_backend_metal_name(ggml_backend_t backend) { - return "Metal"; + ggml_metal_t ctx = (ggml_metal_t)backend->context; - GGML_UNUSED(backend); + return ggml_metal_get_name(ctx); } static void ggml_backend_metal_free(ggml_backend_t backend) { @@ -435,7 +477,6 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { ggml_metal_t ctx = (ggml_metal_t)backend->context; ggml_metal_set_n_cb(ctx, n_cb); - } static ggml_backend_i ggml_backend_metal_i = { @@ -519,15 +560,17 @@ void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) { // backend device static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) { - return "Metal"; + ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; - GGML_UNUSED(dev); + const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx_dev); + + return props_dev->name; } static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) { ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; - return ggml_metal_device_get_props(ctx_dev)->name; + return ggml_metal_device_get_props(ctx_dev)->desc; } static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { @@ -557,7 +600,7 @@ static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, ggml_bac }; } -static ggml_backend_t ggml_backend_metal_device_init(ggml_backend_dev_t dev, const char * params) { +static ggml_backend_t ggml_backend_metal_device_init_backend(ggml_backend_dev_t dev, const char * params) { ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; ggml_metal_t ctx = ggml_metal_init(ctx_dev); @@ -587,7 +630,7 @@ static ggml_backend_buffer_type_t ggml_backend_metal_device_get_buffer_type(ggml const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx_dev); - return props_dev->use_shared_buffers ? ggml_backend_metal_buffer_type_shared() : ggml_backend_metal_buffer_type_private(); + return props_dev->use_shared_buffers ? ggml_backend_metal_buffer_type_shared(props_dev->device) : ggml_backend_metal_buffer_type_private(props_dev->device); } static ggml_backend_buffer_t ggml_backend_metal_device_buffer_mapped(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { @@ -595,7 +638,9 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_mapped(ggml_backen ggml_metal_buffer_t res = ggml_metal_buffer_map(ctx_dev, ptr, size, max_tensor_size); - return ggml_backend_buffer_init(ggml_backend_metal_buffer_type_mapped(), ggml_backend_metal_buffer_shared_i, res, size); + const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx_dev); + + return ggml_backend_buffer_init(ggml_backend_metal_buffer_type_mapped(props_dev->device), ggml_backend_metal_buffer_shared_i, res, size); } static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { @@ -606,9 +651,10 @@ static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { return + buft->device == dev && ( buft->iface.get_name == ggml_backend_metal_buffer_type_shared_get_name || buft->iface.get_name == ggml_backend_metal_buffer_type_private_get_name || - buft->iface.get_name == ggml_backend_metal_buffer_type_mapped_get_name; + buft->iface.get_name == ggml_backend_metal_buffer_type_mapped_get_name); GGML_UNUSED(dev); } @@ -638,7 +684,7 @@ static ggml_backend_device_i ggml_backend_metal_device_i = { /* .get_memory = */ ggml_backend_metal_device_get_memory, /* .get_type = */ ggml_backend_metal_device_get_type, /* .get_props = */ ggml_backend_metal_device_get_props, - /* .init_backend = */ ggml_backend_metal_device_init, + /* .init_backend = */ ggml_backend_metal_device_init_backend, /* .get_buffer_type = */ ggml_backend_metal_device_get_buffer_type, /* .get_host_buffer_type = */ NULL, /* .buffer_from_host_ptr = */ ggml_backend_metal_device_buffer_mapped, @@ -652,25 +698,45 @@ static ggml_backend_device_i ggml_backend_metal_device_i = { // backend registry +struct ggml_backend_metal_reg { + std::vector devices; +}; + +typedef struct ggml_backend_metal_reg * ggml_backend_metal_reg_t; + +static ggml_backend_metal_reg_t ggml_backend_metal_reg_init(void) { + ggml_backend_metal_reg_t ctx = new struct ggml_backend_metal_reg; + + return ctx; +} + +static void ggml_backend_metal_reg_free(ggml_backend_metal_reg_t ctx) { + delete ctx; +} + +struct ggml_backend_metal_reg_deleter { + void operator()(ggml_backend_metal_reg_t ctx) { + ggml_backend_metal_reg_free(ctx); + } +}; + +typedef std::unique_ptr ggml_backend_metal_reg_ptr; + static const char * ggml_backend_metal_reg_get_name(ggml_backend_reg_t reg) { - return "Metal"; + return GGML_METAL_NAME; GGML_UNUSED(reg); } static size_t ggml_backend_metal_reg_device_count(ggml_backend_reg_t reg) { - return 1; - - GGML_UNUSED(reg); + ggml_backend_metal_reg_t ctx = (ggml_backend_metal_reg_t)reg->context; + return ctx->devices.size(); } static ggml_backend_dev_t ggml_backend_metal_reg_device_get(ggml_backend_reg_t reg, size_t index) { - GGML_ASSERT(index == 0); - - return &g_ggml_metal_device; - - GGML_UNUSED(reg); - GGML_UNUSED(index); + ggml_backend_metal_reg_t ctx = (ggml_backend_metal_reg_t)reg->context; + GGML_ASSERT(index < ctx->devices.size()); + return ctx->devices[index]; } static ggml_backend_feature g_ggml_backend_metal_features[] = { @@ -698,27 +764,67 @@ static void * ggml_backend_metal_get_proc_address(ggml_backend_reg_t reg, const static ggml_backend_reg_i ggml_backend_metal_reg_i = { /* .get_name = */ ggml_backend_metal_reg_get_name, - /* .device_count = */ ggml_backend_metal_reg_device_count, - /* .device_get = */ ggml_backend_metal_reg_device_get, + /* .get_device_count = */ ggml_backend_metal_reg_device_count, + /* .get_device = */ ggml_backend_metal_reg_device_get, /* .get_proc_address = */ ggml_backend_metal_get_proc_address, }; -ggml_backend_reg_t ggml_backend_metal_reg(void) { - { - g_ggml_metal_reg = { - /* .api_version = */ GGML_BACKEND_API_VERSION, - /* .iface = */ ggml_backend_metal_reg_i, - /* .context = */ NULL, - }; +static ggml_backend_dev_t ggml_backend_metal_device_init(ggml_backend_reg_t reg, int device) { + return new ggml_backend_device { + /* .iface = */ ggml_backend_metal_device_i, + /* .reg = */ reg, + /* .context = */ ggml_metal_device_get(device), + }; +} - g_ggml_metal_device = { - /* .iface = */ ggml_backend_metal_device_i, - /* .reg = */ &g_ggml_metal_reg, - /* .context = */ ggml_metal_device_get(), - }; +static void ggml_backend_metal_device_free(ggml_backend_dev_t dev) { + delete dev; +} + +struct ggml_backend_device_deleter { + void operator()(ggml_backend_dev_t ctx) { + ggml_backend_metal_device_free(ctx); + } +}; + +typedef std::unique_ptr ggml_backend_device_ptr; + +ggml_backend_reg_t ggml_backend_metal_reg(void) { + static ggml_backend_reg reg; + static bool initialized = false; + + { + static std::mutex mutex; + std::lock_guard lock(mutex); + + const char * env = getenv("GGML_METAL_DEVICES"); + if (env) { + g_devices = atoi(env); + } + + static std::vector devs; + + if (!initialized) { + static ggml_backend_metal_reg_ptr reg_ctx(ggml_backend_metal_reg_init()); + + for (int i = 0; i < g_devices; ++i) { + auto * dev = ggml_backend_metal_device_init(®, i); + devs.emplace_back(dev); + + reg_ctx->devices.push_back(dev); + } + + reg = { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_metal_reg_i, + /* .context = */ reg_ctx.get(), + }; + } + + initialized = true; } - return &g_ggml_metal_reg; + return ® } GGML_BACKEND_DL_IMPL(ggml_backend_metal_reg)