summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJack Mousseau <jmousseau@users.noreply.github.com>2024-01-29 01:22:23 -0800
committerGeorgi Gerganov <ggerganov@gmail.com>2024-01-30 16:20:25 +0200
commit5f14ee0b0cd06f1c4790e6123df4b38ace637e88 (patch)
treea2e37cb93252242e2d62c83f627afd02fca9c500
parent8e14e3ddb3744566aef7bc0fa734180e47ae6bdf (diff)
metal : add debug capture backend function (ggml/694)
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
-rw-r--r--ggml-metal.h3
-rw-r--r--ggml-metal.m40
2 files changed, 37 insertions, 6 deletions
diff --git a/ggml-metal.h b/ggml-metal.h
index df83a180..a5c54218 100644
--- a/ggml-metal.h
+++ b/ggml-metal.h
@@ -57,6 +57,9 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(voi
// ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
GGML_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family);
+// capture all command buffers committed the next time `ggml_backend_graph_compute` is called
+GGML_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend);
+
#ifdef __cplusplus
}
#endif
diff --git a/ggml-metal.m b/ggml-metal.m
index 1b02493f..7e148b6b 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -168,6 +168,8 @@ struct ggml_metal_context {
bool support_simdgroup_reduction;
bool support_simdgroup_mm;
+
+ bool should_capture_next_compute;
};
// MSL code
@@ -687,6 +689,20 @@ static bool ggml_metal_graph_compute(
const int n_cb = ctx->n_cb;
const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
+ const bool should_capture = ctx->should_capture_next_compute;
+ if (should_capture) {
+ ctx->should_capture_next_compute = false;
+
+ MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
+ descriptor.captureObject = ctx->queue;
+
+ NSError * error = nil;
+ if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
+ GGML_METAL_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);
+ GGML_ASSERT(!"capture failed");
+ }
+ }
+
id<MTLCommandBuffer> command_buffer_builder[n_cb];
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
@@ -695,6 +711,7 @@ static bool ggml_metal_graph_compute(
// enqueue the command buffers in order to specify their execution order
[command_buffer enqueue];
}
+
const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) {
@@ -741,9 +758,9 @@ static bool ggml_metal_graph_compute(
GGML_ASSERT(!"unsupported op");
}
-#ifndef GGML_METAL_NDEBUG
- [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]];
-#endif
+ if (should_capture) {
+ [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]];
+ }
const int64_t ne00 = src0 ? src0->ne[0] : 0;
const int64_t ne01 = src0 ? src0->ne[1] : 0;
@@ -2218,9 +2235,9 @@ static bool ggml_metal_graph_compute(
}
}
-#ifndef GGML_METAL_NDEBUG
- [encoder popDebugGroup];
-#endif
+ if (should_capture) {
+ [encoder popDebugGroup];
+ }
}
[encoder endEncoding];
@@ -2242,6 +2259,10 @@ static bool ggml_metal_graph_compute(
}
}
+ if (should_capture) {
+ [[MTLCaptureManager sharedCaptureManager] stopCapture];
+ }
+
return true;
}
@@ -2613,6 +2634,13 @@ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
}
+void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
+ GGML_ASSERT(ggml_backend_is_metal(backend));
+
+ struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+ ctx->should_capture_next_compute = true;
+}
+
GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {