summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Paul Tsochantaris <ptsochantaris@icloud.com> 2024-01-16 17:05:19 +0000
committer: GitHub <noreply@github.com> 2024-01-16 19:05:19 +0200
commit: 158f8c9e21302114bac3c646f80ea85b52ffa0bd (patch)
tree: e06093e1b9d6b5e99415df9c22954983049474f7
parent: 862f5e41ab1fdf12d6f59455aad3f5dd8258f805 (diff)
metal : localized logic in `ggml_metal_graph_compute` (#4924)
* Metal: Localized logic in `ggml_metal_graph_compute`, minor performance improvement
* Whitespace
* Collecting command buffer completions on single thread
* Whitespace
* Reduce diff noise
-rw-r--r-- ggml-metal.h 1
-rw-r--r-- ggml-metal.m 37
2 files changed, 17 insertions, 21 deletions
diff --git a/ggml-metal.h b/ggml-metal.h
index 8b0bfc5f..df83a180 100644
--- a/ggml-metal.h
+++ b/ggml-metal.h
@@ -27,7 +27,6 @@
// max memory buffers that can be mapped to the device
#define GGML_METAL_MAX_BUFFERS 64
-#define GGML_METAL_MAX_COMMAND_BUFFERS 32
struct ggml_tensor;
struct ggml_cgraph;
diff --git a/ggml-metal.m b/ggml-metal.m
index c21dc465..a549e671 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -170,9 +170,6 @@ struct ggml_metal_context {
id<MTLCommandQueue> queue;
id<MTLLibrary> library;
- id<MTLCommandBuffer> command_buffers [GGML_METAL_MAX_COMMAND_BUFFERS];
- id<MTLComputeCommandEncoder> command_encoders[GGML_METAL_MAX_COMMAND_BUFFERS];
-
dispatch_queue_t d_queue;
int n_buffers;
@@ -719,25 +716,25 @@ static bool ggml_metal_graph_compute(
@autoreleasepool {
MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
-
- const int n_nodes = gf->n_nodes;
edesc.dispatchType = MTLDispatchTypeSerial;
// create multiple command buffers and enqueue them
// then, we encode the graph into the command buffers in parallel
+ const int n_nodes = gf->n_nodes;
const int n_cb = ctx->n_cb;
+ const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
- for (int i = 0; i < n_cb; ++i) {
- ctx->command_buffers[i] = [ctx->queue commandBuffer];
+ id<MTLCommandBuffer> command_buffer_builder[n_cb];
+ for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
+ id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
+ command_buffer_builder[cb_idx] = command_buffer;
// enqueue the command buffers in order to specify their execution order
- [ctx->command_buffers[i] enqueue];
-
- ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc];
+ [command_buffer enqueue];
}
+ const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
- const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) {
const int cb_idx = iter;
@@ -745,15 +742,13 @@ static bool ggml_metal_graph_compute(
size_t offs_src1 = 0;
size_t offs_dst = 0;
- id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
- id<MTLComputeCommandEncoder> encoder = ctx->command_encoders[cb_idx];
+ id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
+ id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
const int node_start = (cb_idx + 0) * n_nodes_per_cb;
const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
- for (int ind = node_start; ind < node_end; ++ind) {
- const int i = ind;
-
+ for (int i = node_start; i < node_end; ++i) {
if (i == -1) {
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
continue;
@@ -2249,12 +2244,14 @@ static bool ggml_metal_graph_compute(
[command_buffer commit];
});
- // check status of command buffers
+ // Wait for completion and check status of each command buffer
// needed to detect if the device ran out-of-memory for example (#1881)
- for (int i = 0; i < n_cb; i++) {
- [ctx->command_buffers[i] waitUntilCompleted];
- MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
+ for (int i = 0; i < n_cb; ++i) {
+ id<MTLCommandBuffer> command_buffer = command_buffers[i];
+ [command_buffer waitUntilCompleted];
+
+ MTLCommandBufferStatus status = [command_buffer status];
if (status != MTLCommandBufferStatusCompleted) {
GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
return false;