diff options
author | Paul Tsochantaris <ptsochantaris@icloud.com> | 2024-01-16 17:05:19 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-16 19:05:19 +0200 |
commit | 158f8c9e21302114bac3c646f80ea85b52ffa0bd (patch) | |
tree | e06093e1b9d6b5e99415df9c22954983049474f7 | |
parent | 862f5e41ab1fdf12d6f59455aad3f5dd8258f805 (diff) |
metal : localized logic in `ggml_metal_graph_compute` (#4924)
* Metal: Localized logic in `ggml_metal_graph_compute`, minor performance improvement
* Whitespace
* Collecting command buffer completions on single thread
* Whitespace
* Reduce diff noise
-rw-r--r-- | ggml-metal.h | 1 |
-rw-r--r-- | ggml-metal.m | 37 |
2 files changed, 17 insertions, 21 deletions
diff --git a/ggml-metal.h b/ggml-metal.h index 8b0bfc5f..df83a180 100644 --- a/ggml-metal.h +++ b/ggml-metal.h @@ -27,7 +27,6 @@ // max memory buffers that can be mapped to the device #define GGML_METAL_MAX_BUFFERS 64 -#define GGML_METAL_MAX_COMMAND_BUFFERS 32 struct ggml_tensor; struct ggml_cgraph; diff --git a/ggml-metal.m b/ggml-metal.m index c21dc465..a549e671 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -170,9 +170,6 @@ struct ggml_metal_context { id<MTLCommandQueue> queue; id<MTLLibrary> library; - id<MTLCommandBuffer> command_buffers [GGML_METAL_MAX_COMMAND_BUFFERS]; - id<MTLComputeCommandEncoder> command_encoders[GGML_METAL_MAX_COMMAND_BUFFERS]; - dispatch_queue_t d_queue; int n_buffers; @@ -719,25 +716,25 @@ static bool ggml_metal_graph_compute( @autoreleasepool { MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor; - - const int n_nodes = gf->n_nodes; edesc.dispatchType = MTLDispatchTypeSerial; // create multiple command buffers and enqueue them // then, we encode the graph into the command buffers in parallel + const int n_nodes = gf->n_nodes; const int n_cb = ctx->n_cb; + const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb; - for (int i = 0; i < n_cb; ++i) { - ctx->command_buffers[i] = [ctx->queue commandBuffer]; + id<MTLCommandBuffer> command_buffer_builder[n_cb]; + for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { + id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences]; + command_buffer_builder[cb_idx] = command_buffer; // enqueue the command buffers in order to specify their execution order - [ctx->command_buffers[i] enqueue]; - - ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc]; + [command_buffer enqueue]; } + const id<MTLCommandBuffer> *command_buffers = command_buffer_builder; - const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb; dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) { const int cb_idx = iter; @@ -745,15 +742,13 @@ 
static bool ggml_metal_graph_compute( size_t offs_src1 = 0; size_t offs_dst = 0; - id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx]; - id<MTLComputeCommandEncoder> encoder = ctx->command_encoders[cb_idx]; + id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx]; + id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; const int node_start = (cb_idx + 0) * n_nodes_per_cb; const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes); - for (int ind = node_start; ind < node_end; ++ind) { - const int i = ind; - + for (int i = node_start; i < node_end; ++i) { if (i == -1) { [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]; continue; @@ -2249,12 +2244,14 @@ static bool ggml_metal_graph_compute( [command_buffer commit]; }); - // check status of command buffers + // Wait for completion and check status of each command buffer // needed to detect if the device ran out-of-memory for example (#1881) - for (int i = 0; i < n_cb; i++) { - [ctx->command_buffers[i] waitUntilCompleted]; - MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status]; + for (int i = 0; i < n_cb; ++i) { + id<MTLCommandBuffer> command_buffer = command_buffers[i]; + [command_buffer waitUntilCompleted]; + + MTLCommandBufferStatus status = [command_buffer status]; if (status != MTLCommandBufferStatusCompleted) { GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status); return false; |