summary | refs | log | tree | commit | diff
path: root/ggml/src/ggml-metal.m
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/ggml-metal.m')
-rw-r--r-- ggml/src/ggml-metal.m | 117
1 files changed, 75 insertions, 42 deletions
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index 7d592c22..292f9ac7 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -260,7 +260,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_COUNT
};
-struct ggml_metal_context {
+struct ggml_backend_metal_context {
int n_cb;
id<MTLDevice> device;
@@ -274,6 +274,10 @@ struct ggml_metal_context {
bool support_simdgroup_mm;
bool should_capture_next_compute;
+
+ // abort ggml_metal_graph_compute if callback returns true
+ ggml_abort_callback abort_callback;
+ void * abort_callback_data;
};
// MSL code
@@ -339,7 +343,7 @@ static void * ggml_metal_host_malloc(size_t n) {
return data;
}
-static struct ggml_metal_context * ggml_metal_init(int n_cb) {
+static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
#if TARGET_OS_OSX && !GGML_METAL_NDEBUG
@@ -356,7 +360,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
// Configure context
- struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
+ struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
ctx->device = device;
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
ctx->queue = [ctx->device newCommandQueue];
@@ -761,7 +765,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
return ctx;
}
-static void ggml_metal_free(struct ggml_metal_context * ctx) {
+static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
GGML_METAL_LOG_INFO("%s: deallocating\n", __func__);
for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) {
@@ -827,7 +831,7 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs
return nil;
}
-static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const struct ggml_tensor * op) {
+static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx, const struct ggml_tensor * op) {
for (size_t i = 0, n = 3; i < n; ++i) {
if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
return false;
@@ -938,7 +942,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
}
static enum ggml_status ggml_metal_graph_compute(
- struct ggml_metal_context * ctx,
+ struct ggml_backend_metal_context * ctx,
struct ggml_cgraph * gf) {
@autoreleasepool {
@@ -962,7 +966,7 @@ static enum ggml_status ggml_metal_graph_compute(
NSError * error = nil;
if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
GGML_METAL_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);
- GGML_ASSERT(!"capture failed");
+ GGML_ABORT("capture failed");
}
}
@@ -971,8 +975,11 @@ static enum ggml_status ggml_metal_graph_compute(
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
command_buffer_builder[cb_idx] = command_buffer;
- // enqueue the command buffers in order to specify their execution order
- [command_buffer enqueue];
+ // always enqueue the first two command buffers
+ // enqueue all of the command buffers if we don't need to abort
+ if (cb_idx < 2 || ctx->abort_callback == NULL) {
+ [command_buffer enqueue];
+ }
}
const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
@@ -1024,7 +1031,7 @@ static enum ggml_status ggml_metal_graph_compute(
if (!ggml_metal_supports_op(ctx, dst)) {
GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
- GGML_ASSERT(!"unsupported op");
+ GGML_ABORT("unsupported op");
}
if (should_capture) {
@@ -1207,7 +1214,7 @@ static enum ggml_status ggml_metal_graph_compute(
case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
- default: GGML_ASSERT(false);
+ default: GGML_ABORT("fatal error");
}
bcast_row = true;
@@ -1216,7 +1223,7 @@ static enum ggml_status ggml_metal_graph_compute(
case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
- default: GGML_ASSERT(false);
+ default: GGML_ABORT("fatal error");
}
}
@@ -1270,7 +1277,7 @@ static enum ggml_status ggml_metal_graph_compute(
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break;
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break;
case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break;
- default: GGML_ASSERT(false);
+ default: GGML_ABORT("fatal error");
}
[encoder setComputePipelineState:pipeline];
@@ -1526,7 +1533,7 @@ static enum ggml_status ggml_metal_graph_compute(
default:
{
GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
- GGML_ASSERT(false);
+ GGML_ABORT("fatal error");
}
} break;
case GGML_OP_SQR:
@@ -1756,7 +1763,7 @@ static enum ggml_status ggml_metal_graph_compute(
case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F32 ].pipeline; break;
case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F32 ].pipeline; break;
case GGML_TYPE_IQ6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F32 ].pipeline; break;
- default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
+ default: GGML_ABORT("MUL MAT-MAT not implemented");
}
[encoder setComputePipelineState:pipeline];
@@ -1977,7 +1984,7 @@ static enum ggml_status ggml_metal_graph_compute(
default:
{
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
- GGML_ASSERT(false && "not implemented");
+ GGML_ABORT("not implemented");
}
};
@@ -2117,7 +2124,7 @@ static enum ggml_status ggml_metal_graph_compute(
case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_K_F32 ].pipeline; break;
case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ5_K_F32 ].pipeline; break;
case GGML_TYPE_IQ6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ6_K_F32 ].pipeline; break;
- default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
+ default: GGML_ABORT("MUL_MAT_ID not implemented");
}
[encoder setComputePipelineState:pipeline];
@@ -2332,7 +2339,7 @@ static enum ggml_status ggml_metal_graph_compute(
default:
{
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
- GGML_ASSERT(false && "not implemented");
+ GGML_ABORT("not implemented");
}
};
@@ -2443,7 +2450,7 @@ static enum ggml_status ggml_metal_graph_compute(
case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ5_K ].pipeline; break;
case GGML_TYPE_IQ6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ6_K ].pipeline; break;
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
- default: GGML_ASSERT(false && "not implemented");
+ default: GGML_ABORT("not implemented");
}
[encoder setComputePipelineState:pipeline];
@@ -2494,10 +2501,8 @@ static enum ggml_status ggml_metal_graph_compute(
GGML_ASSERT(ne00 % 4 == 0);
GGML_ASSERT(ggml_is_contiguous(src0));
- //float eps;
- //memcpy(&eps, dst->op_params, sizeof(float));
-
- const float eps = 1e-6f; // TODO: temporarily hardcoded
+ float eps;
+ memcpy(&eps, dst->op_params + 1, sizeof(float));
const int32_t n_groups = ((int32_t *) dst->op_params)[0];
@@ -2581,13 +2586,13 @@ static enum ggml_status ggml_metal_graph_compute(
switch (src0->type) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
- default: GGML_ASSERT(false);
+ default: GGML_ABORT("fatal error");
};
} else {
switch (src0->type) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
- default: GGML_ASSERT(false);
+ default: GGML_ABORT("fatal error");
};
}
@@ -2664,7 +2669,7 @@ static enum ggml_status ggml_metal_graph_compute(
switch (dst->type) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
- default: GGML_ASSERT(false);
+ default: GGML_ABORT("fatal error");
};
[encoder setComputePipelineState:pipeline];
@@ -2821,7 +2826,7 @@ static enum ggml_status ggml_metal_graph_compute(
switch (order) {
case GGML_SORT_ORDER_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break;
case GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
- default: GGML_ASSERT(false);
+ default: GGML_ABORT("fatal error");
};
[encoder setComputePipelineState:pipeline];
@@ -2910,7 +2915,7 @@ static enum ggml_status ggml_metal_graph_compute(
{
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
GGML_METAL_LOG_ERROR("add template specialization for this size\n");
- GGML_ASSERT(false && "add template specialization for this size");
+ GGML_ABORT("add template specialization for this size");
}
}
} else {
@@ -2923,7 +2928,7 @@ static enum ggml_status ggml_metal_graph_compute(
{
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
GGML_METAL_LOG_ERROR("add template specialization for this size\n");
- GGML_ASSERT(false && "add template specialization for this size");
+ GGML_ABORT("add template specialization for this size");
}
}
}
@@ -3044,7 +3049,7 @@ static enum ggml_status ggml_metal_graph_compute(
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break;
- default: GGML_ASSERT(false && "not implemented");
+ default: GGML_ABORT("not implemented");
};
} break;
case GGML_TYPE_F16:
@@ -3052,10 +3057,10 @@ static enum ggml_status ggml_metal_graph_compute(
switch (dstt) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
- default: GGML_ASSERT(false && "not implemented");
+ default: GGML_ABORT("not implemented");
};
} break;
- default: GGML_ASSERT(false && "not implemented");
+ default: GGML_ABORT("not implemented");
}
[encoder setComputePipelineState:pipeline];
@@ -3083,7 +3088,7 @@ static enum ggml_status ggml_metal_graph_compute(
default:
{
GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
- GGML_ASSERT(false);
+ GGML_ABORT("fatal error");
}
}
@@ -3094,7 +3099,9 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder endEncoding];
- [command_buffer commit];
+ if (cb_idx < 2 || ctx->abort_callback == NULL) {
+ [command_buffer commit];
+ }
});
// Wait for completion and check status of each command buffer
@@ -3114,6 +3121,23 @@ static enum ggml_status ggml_metal_graph_compute(
return GGML_STATUS_FAILED;
}
+
+ id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? command_buffers[i + 1] : nil);
+ if (!next_buffer) {
+ continue;
+ }
+
+ bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
+ if (next_queued) {
+ continue;
+ }
+
+ if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
+ GGML_METAL_LOG_INFO("%s: command buffer %d aborted", __func__, i);
+ return GGML_STATUS_ABORTED;
+ }
+
+ [next_buffer commit];
}
if (should_capture) {
@@ -3417,7 +3441,7 @@ GGML_CALL static const char * ggml_backend_metal_name(ggml_backend_t backend) {
}
GGML_CALL static void ggml_backend_metal_free(ggml_backend_t backend) {
- struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+ struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
ggml_metal_free(ctx);
free(backend);
}
@@ -3429,13 +3453,13 @@ GGML_CALL static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffe
}
GGML_CALL static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
- struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context;
+ struct ggml_backend_metal_context * metal_ctx = (struct ggml_backend_metal_context *)backend->context;
return ggml_metal_graph_compute(metal_ctx, cgraph);
}
GGML_CALL static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
- struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context;
+ struct ggml_backend_metal_context * metal_ctx = (struct ggml_backend_metal_context *)backend->context;
return ggml_metal_supports_op(metal_ctx, op);
}
@@ -3480,9 +3504,9 @@ static ggml_guid_t ggml_backend_metal_guid(void) {
}
ggml_backend_t ggml_backend_metal_init(void) {
- struct ggml_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
-
+ struct ggml_backend_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
if (ctx == NULL) {
+ GGML_METAL_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
return NULL;
}
@@ -3504,15 +3528,24 @@ bool ggml_backend_is_metal(ggml_backend_t backend) {
void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
GGML_ASSERT(ggml_backend_is_metal(backend));
- struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+ struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
}
+void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data) {
+ GGML_ASSERT(ggml_backend_is_metal(backend));
+
+ struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
+
+ ctx->abort_callback = abort_callback;
+ ctx->abort_callback_data = user_data;
+}
+
bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
GGML_ASSERT(ggml_backend_is_metal(backend));
- struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+ struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
}
@@ -3520,7 +3553,7 @@ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
GGML_ASSERT(ggml_backend_is_metal(backend));
- struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+ struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
ctx->should_capture_next_compute = true;
}