From 97b96c1ad308eeab4a0bdfd98a952e588799b98b Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sat, 6 Sep 2025 11:53:51 +0300
Subject: [PATCH] metal : make the backend async

ggml-ci
---
 ggml/src/ggml-metal/ggml-metal.m | 330 +++++++++++++++++++++++++------
 1 file changed, 269 insertions(+), 61 deletions(-)

diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index e76fb71263..9626dd3bd5 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -803,6 +803,12 @@ struct ggml_backend_metal_context {
     // n_cb command buffers + 1 used by the main thread
     struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
 
+    // extra command buffers for things like getting, setting and copying tensors
+    NSMutableArray * cmd_bufs_ext;
+
+    id<MTLCommandBuffer> cmd_buf_last;
+    id<MTLCommandBuffer> cmd_buf_ext_last;
+
     // abort ggml_metal_graph_compute if callback returns true
     ggml_abort_callback abort_callback;
     void * abort_callback_data;
@@ -1073,6 +1079,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         ctx->cmd_bufs[i].mem_pool->device = device;
     }
 
+    ctx->cmd_bufs_ext = [[NSMutableArray alloc] init];
+
+    ctx->cmd_buf_last = nil;
+    ctx->cmd_buf_ext_last = nil;
+
 #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
     if (@available(macOS 10.12, iOS 16.0, *)) {
         GGML_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, device.recommendedMaxWorkingSetSize / 1e6);
@@ -1666,11 +1677,15 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
     [ctx->queue release];
 
     for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
-        // ctx->cmd_bufs[i].obj is auto released
+        if (ctx->cmd_bufs[i].obj) {
+            [ctx->cmd_bufs[i].obj release];
+        }
 
         ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool);
     }
 
+    [ctx->cmd_bufs_ext release];
+
     dispatch_release(ctx->d_queue);
 
     free(ctx);
@@ -5778,81 +5793,110 @@ static enum ggml_status ggml_metal_graph_compute(
         }
     }
 
+    // wait for any previous processing
+    if (ctx->cmd_buf_last) {
+        [ctx->cmd_buf_last waitUntilCompleted];
+        ctx->cmd_buf_last = nil;
+    }
+
     // the main thread commits the first few commands immediately
     // cmd_buf[n_cb]
     {
-        id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+        id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBuffer];
+        [cmd_buf retain];
+
+        if (ctx->cmd_bufs[n_cb].obj) {
+            [ctx->cmd_bufs[n_cb].obj release];
+        }
         ctx->cmd_bufs[n_cb].obj = cmd_buf;
 
         [cmd_buf enqueue];
+        ctx->cmd_buf_last = cmd_buf;
+
         ctx->encode_async(n_cb);
     }
 
     // prepare the rest of the command buffers asynchronously
     // cmd_buf[0.. n_cb)
     for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
-        id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+        id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBuffer];
+        [cmd_buf retain];
+
+        if (ctx->cmd_bufs[cb_idx].obj) {
+            [ctx->cmd_bufs[cb_idx].obj release];
+        }
         ctx->cmd_bufs[cb_idx].obj = cmd_buf;
 
         // always enqueue the first two command buffers
        // enqueue all of the command buffers if we don't need to abort
         if (cb_idx < 2 || ctx->abort_callback == NULL) {
             [cmd_buf enqueue];
+            ctx->cmd_buf_last = cmd_buf;
         }
     }
 
     dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async);
 
-    // wait for completion and check status of each command buffer
-    // needed to detect if the device ran out-of-memory for example (#1881)
-    {
-        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
-        [cmd_buf waitUntilCompleted];
-
-        MTLCommandBufferStatus status = [cmd_buf status];
-        if (status != MTLCommandBufferStatusCompleted) {
-            GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
-            if (status == MTLCommandBufferStatusError) {
-                GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
-            }
-
-            return GGML_STATUS_FAILED;
-        }
-    }
-
-    for (int i = 0; i < n_cb; ++i) {
-        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
-        [cmd_buf waitUntilCompleted];
-
-        MTLCommandBufferStatus status = [cmd_buf status];
-        if (status != MTLCommandBufferStatusCompleted) {
-            GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
-            if (status == MTLCommandBufferStatusError) {
-                GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
-            }
-
-            return GGML_STATUS_FAILED;
-        }
-
-        id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
-        if (!next_buffer) {
-            continue;
-        }
-
-        const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
-        if (next_queued) {
-            continue;
-        }
-
-        if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
-            GGML_LOG_INFO("%s: command buffer %d aborted", __func__, i);
-            return GGML_STATUS_ABORTED;
-        }
-
-        [next_buffer commit];
-    }
+    // for debugging: block until graph is computed
+    //[ctx->cmd_buf_last waitUntilCompleted];
 
+    // enter here only when capturing in order to wait for all computation to finish
+    // otherwise, we leave the graph to compute asynchronously
     if (!should_capture && ctx->capture_started) {
+        // wait for completion and check status of each command buffer
+        // needed to detect if the device ran out-of-memory for example (#1881)
+        {
+            id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
+            [cmd_buf waitUntilCompleted];
+
+            MTLCommandBufferStatus status = [cmd_buf status];
+            if (status != MTLCommandBufferStatusCompleted) {
+                GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
+                if (status == MTLCommandBufferStatusError) {
+                    GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
+                }
+
+                return GGML_STATUS_FAILED;
+            }
+        }
+
+        for (int i = 0; i < n_cb; ++i) {
+            id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
+            [cmd_buf waitUntilCompleted];
+
+            MTLCommandBufferStatus status = [cmd_buf status];
+            if (status != MTLCommandBufferStatusCompleted) {
+                GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
+                if (status == MTLCommandBufferStatusError) {
+                    GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
+                }
+
+                return GGML_STATUS_FAILED;
+            }
+
+            id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
+            if (!next_buffer) {
+                continue;
+            }
+
+            const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
+            if (next_queued) {
+                continue;
+            }
+
+            if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
+                GGML_LOG_INFO("%s: command buffer %d aborted", __func__, i);
+                return GGML_STATUS_ABORTED;
+            }
+
+            [next_buffer commit];
+        }
+
+        if (ctx->cmd_buf_last) {
+            [ctx->cmd_buf_last waitUntilCompleted];
+            ctx->cmd_buf_last = nil;
+        }
+
         [ctx->capture_scope endScope];
         [[MTLCaptureManager sharedCaptureManager] stopCapture];
     }
@@ -6034,7 +6078,7 @@ static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_ty
 }
 
 static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
-    return true;
+    return false;
 
     GGML_UNUSED(buft);
 }
@@ -6184,6 +6228,154 @@ static void ggml_backend_metal_free(ggml_backend_t backend) {
     free(backend);
 }
 
+static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
+    struct ggml_backend_metal_context * ctx = backend->context;
+
+    if (ctx->cmd_buf_last) {
+        [ctx->cmd_buf_last waitUntilCompleted];
+        ctx->cmd_buf_last = nil;
+    }
+
+    if (ctx->cmd_buf_ext_last) {
+        [ctx->cmd_buf_ext_last waitUntilCompleted];
+        ctx->cmd_buf_ext_last = nil;
+    }
+
+    for (size_t i = 0; i < ctx->cmd_bufs_ext.count; ++i) {
+        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs_ext[i];
+
+        // check status and assert that the command buffer completed successfully
+        MTLCommandBufferStatus status = [cmd_buf status];
+        if (status != MTLCommandBufferStatusCompleted) {
+            GGML_LOG_ERROR("%s: error: command buffer %d failed with status %d\n", __func__, (int) i, (int) status);
+            if (status == MTLCommandBufferStatusError) {
+                GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
+            }
+            GGML_ABORT("fatal error");
+        }
+
+        //printf("releasing buffer %d\n", (int) i);
+        [cmd_buf release];
+    }
+
+    [ctx->cmd_bufs_ext removeAllObjects];
+}
+
+static void ggml_backend_metal_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    struct ggml_backend_metal_context * ctx = backend->context;
+    struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+
+    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
+
+    struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *)buf->context;
+
+    @autoreleasepool {
+        id<MTLDevice> device = ctx_dev->mtl_device;
+
+        id<MTLBuffer> buf_src = [device newBufferWithBytes:data
+                                                    length:size
+                                                   options:MTLResourceStorageModeShared];
+
+        size_t tensor_offset = (uintptr_t)tensor->data + offset;
+
+        // find which buffer contains this tensor
+        for (int i = 0; i < buf_ctx->n_buffers; i++) {
+            if (tensor_offset >= (uintptr_t) buf_ctx->buffers[i].data &&
+                tensor_offset < (uintptr_t) buf_ctx->buffers[i].data + buf_ctx->buffers[i].size) {
+
+                const size_t buf_dst_offset = tensor_offset - (uintptr_t) buf_ctx->buffers[i].data;
+
+                id<MTLBuffer> buf_dst = buf_ctx->buffers[i].metal;
+
+                id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBuffer];
+                [cmd_buf enqueue];
+
+                id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
+
+                [encoder copyFromBuffer:buf_src
+                           sourceOffset:0
+                               toBuffer:buf_dst
+                      destinationOffset:buf_dst_offset
+                                   size:size];
+
+                [encoder endEncoding];
+                [cmd_buf commit];
+                //[cmd_buf waitUntilCompleted];
+
+                [ctx->cmd_bufs_ext addObject:cmd_buf];
+                ctx->cmd_buf_ext_last = cmd_buf;
+
+                [cmd_buf retain];
+                return;
+            }
+        }
+
+        GGML_ABORT("%s: failed to find buffer for tensor '%s'\n", __func__, tensor->name);
+    }
+}
+
+static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    struct ggml_backend_metal_context * ctx = backend->context;
+    struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+
+    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
+
+    struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *)buf->context;
+
+    @autoreleasepool {
+        id<MTLDevice> device = ctx_dev->mtl_device;
+
+        id<MTLBuffer> buf_dst = [device newBufferWithBytesNoCopy:data
+                                                          length:size
+                                                         options:MTLResourceStorageModeShared
+                                                     deallocator:nil];
+
+        const size_t tensor_offset = (uintptr_t)tensor->data + offset;
+
+        // find which buffer contains this tensor data
+        for (int i = 0; i < buf_ctx->n_buffers; i++) {
+            if (tensor_offset >= (uintptr_t) buf_ctx->buffers[i].data &&
+                tensor_offset < (uintptr_t) buf_ctx->buffers[i].data + buf_ctx->buffers[i].size) {
+
+                const size_t buf_src_offset = tensor_offset - (uintptr_t) buf_ctx->buffers[i].data;
+
+                id<MTLBuffer> buf_src = buf_ctx->buffers[i].metal;
+
+                id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBuffer];
+                [cmd_buf enqueue];
+
+                id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
+
+                [encoder copyFromBuffer:buf_src
+                           sourceOffset:buf_src_offset
+                               toBuffer:buf_dst
+                      destinationOffset:0
+                                   size:size];
+
+                [encoder endEncoding];
+                [cmd_buf commit];
+                //[cmd_buf waitUntilCompleted];
+
+                [ctx->cmd_bufs_ext addObject:cmd_buf];
+                ctx->cmd_buf_ext_last = cmd_buf;
+
+                [cmd_buf retain];
+                return;
+            }
+        }
+
+        GGML_ABORT("%s: failed to find buffer for tensor '%s'\n", __func__, tensor->name);
+    }
+}
+
+static bool ggml_backend_metal_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst) {
+    return false;
+
+    GGML_UNUSED(backend_src);
+    GGML_UNUSED(backend_dst);
+    GGML_UNUSED(src);
+    GGML_UNUSED(dst);
+}
+
 static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
     return ggml_metal_graph_compute(backend, cgraph);
 }
@@ -6214,7 +6406,10 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
 
         const int n_nodes_per_cb = ctx->n_nodes_per_cb;
 
-        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
+        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
+        struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
+
+        ggml_metal_mem_pool_reset(mem_pool);
 
         id<MTLComputeCommandEncoder> encoder = [cmd_buf computeCommandEncoder];
 
@@ -6228,9 +6423,6 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
 
         const bool should_capture = ctx->capture_next_compute;
 
-        struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
-        ggml_metal_mem_pool_reset(mem_pool);
-
         for (int idx = node_start; idx < node_end;) {
             if (should_capture) {
                 [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
@@ -6264,15 +6456,19 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
 static struct ggml_backend_i ggml_backend_metal_i = {
     /* .get_name           = */ ggml_backend_metal_name,
     /* .free               = */ ggml_backend_metal_free,
-    /* .set_tensor_async   = */ NULL,
-    /* .get_tensor_async   = */ NULL,
-    /* .cpy_tensor_async   = */ NULL,
-    /* .synchronize        = */ NULL,
+    /* .set_tensor_async   = */ ggml_backend_metal_set_tensor_async,
+    /* .get_tensor_async   = */ ggml_backend_metal_get_tensor_async,
+    /* .cpy_tensor_async   = */ ggml_backend_metal_cpy_tensor_async, // only needed for multi-GPU setups
+    /* .synchronize        = */ ggml_backend_metal_synchronize,
     /* .graph_plan_create  = */ NULL,
     /* .graph_plan_free    = */ NULL,
     /* .graph_plan_update  = */ NULL,
     /* .graph_plan_compute = */ NULL,
     /* .graph_compute      = */ ggml_backend_metal_graph_compute,
+
+    // the events API is needed only for multi-GPU setups, so likely no need to implement it for Metal
+    // in any case, these docs seem relevant if we ever decide to implement it:
+    // https://developer.apple.com/documentation/metal/mtlcommandbuffer#Synchronizing-Passes-with-Events
     /* .event_record       = */ NULL,
     /* .event_wait         = */ NULL,
     /* .optimize_graph     = */ NULL,
@@ -6514,8 +6710,20 @@ static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml
     GGML_UNUSED(dev);
 }
 
+static int64_t get_op_batch_size(const struct ggml_tensor * op) {
+    switch (op->op) {
+        case GGML_OP_MUL_MAT_ID:
+            return op->ne[1];
+        default:
+            return ggml_nrows(op);
+    }
+}
+
 static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
-    return false;
+    const int min_batch_size = 32;
+
+    return get_op_batch_size(op) >= min_batch_size;
+    //return false;
 
     GGML_UNUSED(dev);
     GGML_UNUSED(op);
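
With these hooks in place, graph computation and tensor transfers on the Metal backend return before the GPU has finished, and the caller observes the results only after the backend's synchronize hook runs. A rough usage sketch in terms of the public ggml-backend API follows; it is not part of the patch, and the helper name below is made up for illustration:

#include "ggml.h"
#include "ggml-backend.h"

// illustrative only: upload an input tensor, run a graph, and block once at the end
static void run_graph_async(ggml_backend_t backend, struct ggml_tensor * input,
                            struct ggml_cgraph * gf, const void * host_data) {
    // enqueues a blit on an extra command buffer (cmd_bufs_ext) and returns immediately
    ggml_backend_tensor_set_async(backend, input, host_data, 0, ggml_nbytes(input));

    // submits the graph's command buffers without waiting for them to complete
    ggml_backend_graph_compute_async(backend, gf);

    // waits on cmd_buf_last / cmd_buf_ext_last and releases the extra command buffers
    ggml_backend_synchronize(backend);
}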