summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ggml/src/ggml-metal.m3988
-rw-r--r--ggml/src/ggml-metal.metal150
2 files changed, 2107 insertions, 2031 deletions
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index 0498be1f..aa50d448 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -297,8 +297,9 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_COUNT
};
+#define GGML_METAL_MAX_COMMAND_BUFFERS 8
+
struct ggml_backend_metal_context {
- int n_cb;
id<MTLDevice> device;
id<MTLCommandQueue> queue;
@@ -307,6 +308,26 @@ struct ggml_backend_metal_context {
struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT];
+ // capture state
+ bool capture_next_compute;
+ bool capture_started;
+
+ id<MTLCaptureScope> capture_scope;
+
+ // command buffer state
+ int n_cb; // number of extra threads used to submit the command buffers
+ int n_nodes_0; // number of nodes submitted by the main thread
+ int n_nodes_1; // remaining number of nodes submitted by the n_cb threads
+ int n_nodes_per_cb;
+
+ struct ggml_cgraph * gf;
+
+ // the callback given to the thread pool
+ void (^encode_async)(size_t ith);
+
+ // n_cb command buffers + 1 used by the main thread
+ id<MTLCommandBuffer> command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
+
bool support_simdgroup_reduction;
bool support_simdgroup_mm;
@@ -373,7 +394,6 @@ static void * ggml_metal_host_malloc(size_t n) {
const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
if (result != 0) {
GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
- return NULL;
}
#endif
@@ -533,6 +553,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
ctx->should_capture_next_compute = false;
+ ctx->capture_started = false;
+ ctx->capture_scope = nil;
+
+ ctx->gf = nil;
+ ctx->encode_async = nil;
+ for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
+ ctx->command_buffers[i] = nil;
+ }
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
if (@available(macOS 10.12, iOS 16.0, *)) {
@@ -846,6 +874,8 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
[ctx->kernels[i].pipeline release];
}
+ Block_release(ctx->encode_async);
+
[ctx->queue release];
[ctx->device release];
@@ -1027,934 +1057,871 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
}
}
-static enum ggml_status ggml_metal_graph_compute(
+static void ggml_metal_encode_node(
struct ggml_backend_metal_context * ctx,
- struct ggml_cgraph * gf) {
-
- @autoreleasepool {
- MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
- edesc.dispatchType = MTLDispatchTypeSerial;
+ struct ggml_tensor * node,
+ id<MTLComputeCommandEncoder> encoder) {
- // create multiple command buffers and enqueue them
- // then, we encode the graph into the command buffers in parallel
- const int n_nodes = gf->n_nodes;
- const int n_cb = ctx->n_cb;
- const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
+ struct ggml_tensor * src0 = node->src[0];
+ struct ggml_tensor * src1 = node->src[1];
+ struct ggml_tensor * src2 = node->src[2];
+ struct ggml_tensor * dst = node;
- const bool should_capture = ctx->should_capture_next_compute;
- if (should_capture) {
- ctx->should_capture_next_compute = false;
+ if (ggml_is_empty(dst)) {
+ return;
+ }
- MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
- descriptor.captureObject = ctx->queue;
+ switch (dst->op) {
+ case GGML_OP_NONE:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_PERMUTE: return; // noop
+ default: break;
+ }
- NSError * error = nil;
- if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
- GGML_METAL_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);
- GGML_ABORT("capture failed");
- }
+ if (!ggml_metal_supports_op(ctx, dst)) {
+ GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
+ GGML_ABORT("unsupported op");
}
- id<MTLCommandBuffer> command_buffer_builder[n_cb];
- for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
- id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
- command_buffer_builder[cb_idx] = command_buffer;
+ const int64_t ne00 = src0 ? src0->ne[0] : 0;
+ const int64_t ne01 = src0 ? src0->ne[1] : 0;
+ const int64_t ne02 = src0 ? src0->ne[2] : 0;
+ const int64_t ne03 = src0 ? src0->ne[3] : 0;
+
+ const uint64_t nb00 = src0 ? src0->nb[0] : 0;
+ const uint64_t nb01 = src0 ? src0->nb[1] : 0;
+ const uint64_t nb02 = src0 ? src0->nb[2] : 0;
+ const uint64_t nb03 = src0 ? src0->nb[3] : 0;
+
+ const int64_t ne10 = src1 ? src1->ne[0] : 0;
+ const int64_t ne11 = src1 ? src1->ne[1] : 0;
+ const int64_t ne12 = src1 ? src1->ne[2] : 0;
+ const int64_t ne13 = src1 ? src1->ne[3] : 0;
+
+ const uint64_t nb10 = src1 ? src1->nb[0] : 0;
+ const uint64_t nb11 = src1 ? src1->nb[1] : 0;
+ const uint64_t nb12 = src1 ? src1->nb[2] : 0;
+ const uint64_t nb13 = src1 ? src1->nb[3] : 0;
+
+ const int64_t ne20 = src2 ? src2->ne[0] : 0;
+ const int64_t ne21 = src2 ? src2->ne[1] : 0;
+ const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22);
+ const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
+
+ const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
+ const uint64_t nb21 = src2 ? src2->nb[1] : 0;
+ const uint64_t nb22 = src2 ? src2->nb[2] : 0;
+ const uint64_t nb23 = src2 ? src2->nb[3] : 0;
+
+ const int64_t ne0 = dst ? dst->ne[0] : 0;
+ const int64_t ne1 = dst ? dst->ne[1] : 0;
+ const int64_t ne2 = dst ? dst->ne[2] : 0;
+ const int64_t ne3 = dst ? dst->ne[3] : 0;
+
+ const uint64_t nb0 = dst ? dst->nb[0] : 0;
+ const uint64_t nb1 = dst ? dst->nb[1] : 0;
+ const uint64_t nb2 = dst ? dst->nb[2] : 0;
+ const uint64_t nb3 = dst ? dst->nb[3] : 0;
+
+ const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
+ const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
+ const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
+
+ size_t offs_src0 = 0;
+ size_t offs_src1 = 0;
+ size_t offs_src2 = 0;
+ size_t offs_dst = 0;
+
+ id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
+ id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
+ id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
+ id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
+
+ //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
+ //if (src0) {
+ // GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
+ // ggml_is_contiguous(src0), src0->name);
+ //}
+ //if (src1) {
+ // GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
+ // ggml_is_contiguous(src1), src1->name);
+ //}
+ //if (dst) {
+ // GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
+ // dst->name);
+ //}
+
+ switch (dst->op) {
+ case GGML_OP_CONCAT:
+ {
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
+
+ const int32_t dim = ((int32_t *) dst->op_params)[0];
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
+ [encoder setBytes:&dim length:sizeof(dim) atIndex:27];
+
+ const int nth = MIN(1024, ne0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_ADD:
+ case GGML_OP_MUL:
+ case GGML_OP_DIV:
+ {
+ GGML_ASSERT(src0t == GGML_TYPE_F32);
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
- // always enqueue the first two command buffers
- // enqueue all of the command buffers if we don't need to abort
- if (cb_idx < 2 || ctx->abort_callback == NULL) {
- [command_buffer enqueue];
- }
- }
+ const size_t offs = 0;
- const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
+ bool bcast_row = false;
- dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) {
- const int cb_idx = iter;
+ int64_t nb = ne00; // used by the "row" kernels
- size_t offs_src0 = 0;
- size_t offs_src1 = 0;
- size_t offs_src2 = 0;
- size_t offs_dst = 0;
+ id<MTLComputePipelineState> pipeline = nil;
- id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
- id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+ if (dst->op == GGML_OP_MUL && ggml_nelements(src1) == 1 && ggml_is_contiguous(src0)) {
+ float scale;
+ memcpy(&scale, src1->data, sizeof(float));
+ //printf("Replacing op_mul with op_scale. scale = %g\n", (double)scale);
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline;
- const int node_start = (cb_idx + 0) * n_nodes_per_cb;
- const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
+ int64_t n = ggml_nelements(dst);
- for (int i = node_start; i < node_end; ++i) {
- if (i == -1) {
- [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
- continue;
- }
+ if (n % 4 == 0) {
+ n /= 4;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline;
+ }
- //GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
- struct ggml_tensor * src0 = gf->nodes[i]->src[0];
- struct ggml_tensor * src1 = gf->nodes[i]->src[1];
- struct ggml_tensor * src2 = gf->nodes[i]->src[2];
- struct ggml_tensor * dst = gf->nodes[i];
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ break;
+ }
+ else if (ggml_is_contiguous(dst->src[0]) && ggml_is_contiguous(dst->src[1]) && ggml_is_contiguous(dst) &&
+ dst->src[0]->ne[0] == dst->src[1]->ne[0] && dst->src[0]->ne[0] == dst->ne[0] &&
+ dst->src[0]->ne[1] == dst->src[1]->ne[1] && dst->src[0]->ne[1] == dst->ne[1] &&
+ dst->src[0]->ne[2] == dst->src[1]->ne[2] && dst->src[0]->ne[2] == dst->ne[2] &&
+ dst->src[0]->ne[3] == dst->src[1]->ne[3] && ggml_nelements(dst)%4 == 0) {
+
+ switch (dst->op) {
+ case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_4].pipeline; break;
+ case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_4].pipeline; break;
+ case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_4].pipeline; break;
+ default: GGML_ASSERT(false);
+ }
- if (ggml_is_empty(dst)) {
- continue;
- }
+ int64_t n = ggml_nelements(dst)/4;
- switch (dst->op) {
- case GGML_OP_NONE:
- case GGML_OP_RESHAPE:
- case GGML_OP_VIEW:
- case GGML_OP_TRANSPOSE:
- case GGML_OP_PERMUTE:
- {
- // noop -> next node
- } continue;
- default:
- {
- } break;
- }
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- if (!ggml_metal_supports_op(ctx, dst)) {
- GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
- GGML_ABORT("unsupported op");
- }
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ break;
+ }
+ else if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ // src1 is a row
+ GGML_ASSERT(ne11 == 1);
+
+ nb = ne00 / 4;
+ switch (dst->op) {
+ case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
+ case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
+ case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
+ default: GGML_ABORT("fatal error");
+ }
- if (should_capture) {
- [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]];
- }
+ bcast_row = true;
+ } else {
+ switch (dst->op) {
+ case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
+ case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
+ case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
+ default: GGML_ABORT("fatal error");
+ }
+ }
- const int64_t ne00 = src0 ? src0->ne[0] : 0;
- const int64_t ne01 = src0 ? src0->ne[1] : 0;
- const int64_t ne02 = src0 ? src0->ne[2] : 0;
- const int64_t ne03 = src0 ? src0->ne[3] : 0;
-
- const uint64_t nb00 = src0 ? src0->nb[0] : 0;
- const uint64_t nb01 = src0 ? src0->nb[1] : 0;
- const uint64_t nb02 = src0 ? src0->nb[2] : 0;
- const uint64_t nb03 = src0 ? src0->nb[3] : 0;
-
- const int64_t ne10 = src1 ? src1->ne[0] : 0;
- const int64_t ne11 = src1 ? src1->ne[1] : 0;
- const int64_t ne12 = src1 ? src1->ne[2] : 0;
- const int64_t ne13 = src1 ? src1->ne[3] : 0;
-
- const uint64_t nb10 = src1 ? src1->nb[0] : 0;
- const uint64_t nb11 = src1 ? src1->nb[1] : 0;
- const uint64_t nb12 = src1 ? src1->nb[2] : 0;
- const uint64_t nb13 = src1 ? src1->nb[3] : 0;
-
- const int64_t ne20 = src2 ? src2->ne[0] : 0;
- const int64_t ne21 = src2 ? src2->ne[1] : 0;
- const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22);
- const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
-
- const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
- const uint64_t nb21 = src2 ? src2->nb[1] : 0;
- const uint64_t nb22 = src2 ? src2->nb[2] : 0;
- const uint64_t nb23 = src2 ? src2->nb[3] : 0;
-
- const int64_t ne0 = dst ? dst->ne[0] : 0;
- const int64_t ne1 = dst ? dst->ne[1] : 0;
- const int64_t ne2 = dst ? dst->ne[2] : 0;
- const int64_t ne3 = dst ? dst->ne[3] : 0;
-
- const uint64_t nb0 = dst ? dst->nb[0] : 0;
- const uint64_t nb1 = dst ? dst->nb[1] : 0;
- const uint64_t nb2 = dst ? dst->nb[2] : 0;
- const uint64_t nb3 = dst ? dst->nb[3] : 0;
-
- const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
- const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
- const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
-
- id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
- id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
- id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
- id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
-
- //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
- //if (src0) {
- // GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
- // ggml_is_contiguous(src0), src0->name);
- //}
- //if (src1) {
- // GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
- // ggml_is_contiguous(src1), src1->name);
- //}
- //if (dst) {
- // GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
- // dst->name);
- //}
-
- switch (dst->op) {
- case GGML_OP_CONCAT:
- {
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
-
- const int32_t dim = ((int32_t *) dst->op_params)[0];
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
- [encoder setBytes:&dim length:sizeof(dim) atIndex:27];
-
- const int nth = MIN(1024, ne0);
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_ADD:
- case GGML_OP_MUL:
- case GGML_OP_DIV:
- {
- GGML_ASSERT(src0t == GGML_TYPE_F32);
- GGML_ASSERT(src1t == GGML_TYPE_F32);
-
- const size_t offs = 0;
-
- bool bcast_row = false;
-
- int64_t nb = ne00; // used by the "row" kernels
-
- id<MTLComputePipelineState> pipeline = nil;
-
- if (dst->op == GGML_OP_MUL && ggml_nelements(src1) == 1 && ggml_is_contiguous(src0)) {
- float scale;
- memcpy(&scale, src1->data, sizeof(float));
- //printf("Replacing op_mul with op_scale. scale = %g\n", (double)scale);
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline;
-
- int64_t n = ggml_nelements(dst);
-
- if (n % 4 == 0) {
- n /= 4;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline;
- } else {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline;
- }
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
+ [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:28];
+
+ if (bcast_row) {
+ const int64_t n = ggml_nelements(dst)/4;
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } else {
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ }
+ } break;
+ case GGML_OP_MULTI_ADD:
+ {
+ GGML_ASSERT(src0t == GGML_TYPE_F32);
+ GGML_ASSERT(dstt == GGML_TYPE_F32);
+ GGML_ASSERT(ne02 == 1 && ne03 == 1);
+ GGML_ASSERT(nb0 == sizeof(float) && nb00 == sizeof(float));
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+ int n_expert = dst->op_params[0];
+ GGML_ASSERT(n_expert >= 2);
+
+ id<MTLComputePipelineState> pipeline = nil;
+ int64_t n = ne0*ne1;
+ if (ne0%4 == 0) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MULTI_ADD_4].pipeline;
+ n /= 4;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MULTI_ADD].pipeline;
+ }
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:2];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:3];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:4];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
+ [encoder setBytes:&n_expert length:sizeof(n_expert) atIndex:6];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_REPEAT:
+ {
+ id<MTLComputePipelineState> pipeline;
+
+ switch (src0t) {
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break;
+ case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break;
+ case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break;
+ default: GGML_ABORT("fatal error");
+ }
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
+
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_ACC:
+ {
+ GGML_ASSERT(src0t == GGML_TYPE_F32);
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+ GGML_ASSERT(dstt == GGML_TYPE_F32);
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(src1));
+
+ const size_t pnb1 = ((int32_t *) dst->op_params)[0];
+ const size_t pnb2 = ((int32_t *) dst->op_params)[1];
+ const size_t pnb3 = ((int32_t *) dst->op_params)[2];
+ const size_t offs = ((int32_t *) dst->op_params)[3];
+
+ const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
+
+ if (!inplace) {
+ // run a separete kernel to cpy src->dst
+ // not sure how to avoid this
+ // TODO: make a simpler cpy_bytes kernel
+
+ const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
+
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ }
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- break;
- }
- else if (ggml_is_contiguous(dst->src[0]) && ggml_is_contiguous(dst->src[1]) && ggml_is_contiguous(dst) &&
- dst->src[0]->ne[0] == dst->src[1]->ne[0] && dst->src[0]->ne[0] == dst->ne[0] &&
- dst->src[0]->ne[1] == dst->src[1]->ne[1] && dst->src[0]->ne[1] == dst->ne[1] &&
- dst->src[0]->ne[2] == dst->src[1]->ne[2] && dst->src[0]->ne[2] == dst->ne[2] &&
- dst->src[0]->ne[3] == dst->src[1]->ne[3] && ggml_nelements(dst)%4 == 0) {
-
- switch (dst->op) {
- case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_4].pipeline; break;
- case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_4].pipeline; break;
- case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_4].pipeline; break;
- default: GGML_ASSERT(false);
- }
+ const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+ [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
+ [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
+ [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
+ [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
+ [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
+ [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
+ [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
+
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_SCALE:
+ {
+ GGML_ASSERT(ggml_is_contiguous(src0));
- int64_t n = ggml_nelements(dst)/4;
+ float scale;
+ memcpy(&scale, dst->op_params, sizeof(scale));
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ int64_t n = ggml_nelements(dst);
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- break;
- }
- else if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
- GGML_ASSERT(ggml_is_contiguous(src0));
-
- // src1 is a row
- GGML_ASSERT(ne11 == 1);
-
- nb = ne00 / 4;
- switch (dst->op) {
- case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
- case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
- case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
- default: GGML_ABORT("fatal error");
- }
+ id<MTLComputePipelineState> pipeline = nil;
- bcast_row = true;
- } else {
- switch (dst->op) {
- case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
- case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
- case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
- default: GGML_ABORT("fatal error");
- }
- }
+ if (n % 4 == 0) {
+ n /= 4;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline;
+ }
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
- [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
- [encoder setBytes:&nb length:sizeof(nb) atIndex:28];
-
- if (bcast_row) {
- const int64_t n = ggml_nelements(dst)/4;
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } else {
- const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- }
- } break;
- case GGML_OP_MULTI_ADD:
- {
- GGML_ASSERT(src0t == GGML_TYPE_F32);
- GGML_ASSERT(dstt == GGML_TYPE_F32);
- GGML_ASSERT(ne02 == 1 && ne03 == 1);
- GGML_ASSERT(nb0 == sizeof(float) && nb00 == sizeof(float));
- GGML_ASSERT(ggml_are_same_shape(src0, dst));
-
- int n_expert = dst->op_params[0];
- GGML_ASSERT(n_expert >= 2);
-
- id<MTLComputePipelineState> pipeline = nil;
- int64_t n = ne0*ne1;
- if (ne0%4 == 0) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MULTI_ADD_4].pipeline;
- n /= 4;
- } else {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MULTI_ADD].pipeline;
- }
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:2];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:3];
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:4];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
- [encoder setBytes:&n_expert length:sizeof(n_expert) atIndex:6];
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_OP_REPEAT:
- {
- id<MTLComputePipelineState> pipeline;
-
- switch (src0t) {
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break;
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break;
- case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break;
- case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break;
- default: GGML_ABORT("fatal error");
- }
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_SOFTCAP:
+ {
+ GGML_ASSERT(ggml_is_contiguous(src0));
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
-
- const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_ACC:
- {
- GGML_ASSERT(src0t == GGML_TYPE_F32);
- GGML_ASSERT(src1t == GGML_TYPE_F32);
- GGML_ASSERT(dstt == GGML_TYPE_F32);
-
- GGML_ASSERT(ggml_is_contiguous(src0));
- GGML_ASSERT(ggml_is_contiguous(src1));
-
- const size_t pnb1 = ((int32_t *) dst->op_params)[0];
- const size_t pnb2 = ((int32_t *) dst->op_params)[1];
- const size_t pnb3 = ((int32_t *) dst->op_params)[2];
- const size_t offs = ((int32_t *) dst->op_params)[3];
-
- const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
-
- if (!inplace) {
- // run a separete kernel to cpy src->dst
- // not sure how to avoid this
- // TODO: make a simpler cpy_bytes kernel
-
- const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;
+ float scales[2];
+ memcpy(scales, dst->op_params, sizeof(scales));
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
-
- const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- }
+ int64_t n = ggml_nelements(dst);
- const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
- [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
- [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
- [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
- [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
- [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
- [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
- [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
-
- const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_SCALE:
- {
- GGML_ASSERT(ggml_is_contiguous(src0));
-
- float scale;
- memcpy(&scale, dst->op_params, sizeof(scale));
-
- int64_t n = ggml_nelements(dst);
-
- id<MTLComputePipelineState> pipeline = nil;
-
- if (n % 4 == 0) {
- n /= 4;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline;
- } else {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline;
- }
+ id<MTLComputePipelineState> pipeline = nil;
+
+ if (n % 4 == 0) {
+ n /= 4;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFTCAP_4].pipeline;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFTCAP].pipeline;
+ }
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&scales[0] length:sizeof(float) atIndex:2];
+ [encoder setBytes:&scales[1] length:sizeof(float) atIndex:3];
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_OP_SOFTCAP:
- {
- GGML_ASSERT(ggml_is_contiguous(src0));
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_CLAMP:
+ {
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
- float scales[2];
- memcpy(scales, dst->op_params, sizeof(scales));
+ float min;
+ float max;
+ memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
+ memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
- int64_t n = ggml_nelements(dst);
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&min length:sizeof(min) atIndex:2];
+ [encoder setBytes:&max length:sizeof(max) atIndex:3];
- id<MTLComputePipelineState> pipeline = nil;
+ const int64_t n = ggml_nelements(dst);
- if (n % 4 == 0) {
- n /= 4;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFTCAP_4].pipeline;
- } else {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFTCAP].pipeline;
- }
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_UNARY:
+ switch (ggml_get_unary_op(node)) {
+ // we are not taking into account the strides, so for now require contiguous tensors
+ GGML_ASSERT(ggml_is_contiguous(src0));
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&scales[0] length:sizeof(float) atIndex:2];
- [encoder setBytes:&scales[1] length:sizeof(float) atIndex:3];
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_OP_CLAMP:
- {
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
-
- float min;
- float max;
- memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
- memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&min length:sizeof(min) atIndex:2];
- [encoder setBytes:&max length:sizeof(max) atIndex:3];
-
- const int64_t n = ggml_nelements(dst);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_OP_UNARY:
- switch (ggml_get_unary_op(gf->nodes[i])) {
- // we are not taking into account the strides, so for now require contiguous tensors
- GGML_ASSERT(ggml_is_contiguous(src0));
-
- case GGML_UNARY_OP_TANH:
- {
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline;
+ case GGML_UNARY_OP_TANH:
+ {
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline;
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- const int64_t n = ggml_nelements(dst);
+ const int64_t n = ggml_nelements(dst);
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_UNARY_OP_RELU:
- {
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline;
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_UNARY_OP_RELU:
+ {
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline;
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- const int64_t n = ggml_nelements(dst);
+ const int64_t n = ggml_nelements(dst);
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_UNARY_OP_SIGMOID:
- {
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline;
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_UNARY_OP_SIGMOID:
+ {
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline;
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- const int64_t n = ggml_nelements(dst);
+ const int64_t n = ggml_nelements(dst);
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_UNARY_OP_GELU:
- {
- int64_t n = ggml_nelements(dst);
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_UNARY_OP_GELU:
+ {
+ int64_t n = ggml_nelements(dst);
- id<MTLComputePipelineState> pipeline = nil;
+ id<MTLComputePipelineState> pipeline = nil;
- if (n % 4 == 0) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline;
- n /= 4;
- } else {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
- }
+ if (n % 4 == 0) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline;
+ n /= 4;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
+ }
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_UNARY_OP_GELU_QUICK:
- {
- int64_t n = ggml_nelements(dst);
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_UNARY_OP_GELU_QUICK:
+ {
+ int64_t n = ggml_nelements(dst);
- id<MTLComputePipelineState> pipeline = nil;
+ id<MTLComputePipelineState> pipeline = nil;
- if (n % 4 == 0) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline;
- n /= 4;
- } else {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
- }
+ if (n % 4 == 0) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline;
+ n /= 4;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
+ }
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_UNARY_OP_SILU:
- {
- int64_t n = ggml_nelements(dst);
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_UNARY_OP_SILU:
+ {
+ int64_t n = ggml_nelements(dst);
- id<MTLComputePipelineState> pipeline = nil;
+ id<MTLComputePipelineState> pipeline = nil;
- if (n % 4 == 0) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline;
- n /= 4;
- } else {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
- }
+ if (n % 4 == 0) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline;
+ n /= 4;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
+ }
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_UNARY_OP_SWIGLU:
- {
- int64_t n = ggml_nelements(dst);
- GGML_ASSERT(ne0 == src0->ne[0]/2);
-
- id<MTLComputePipelineState> pipeline = nil;
-
- uint32_t n_per_row = ne0;
- uint32_t stride = src0->nb[1]/sizeof(float);
-
- if (ne0 % 4 == 0) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU_4].pipeline;
- n /= 4;
- n_per_row /= 4;
- stride /= 4;
- } else {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
- }
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&n_per_row length:sizeof(n_per_row) atIndex:2];
- [encoder setBytes:&stride length:sizeof(stride) atIndex:3];
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- default:
- {
- GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
- GGML_ABORT("fatal error");
- }
- } break;
- case GGML_OP_FUSED_MUL_UNARY:
- {
- int64_t n = ggml_nelements(dst);
- enum ggml_unary_op op = (enum ggml_unary_op)dst->op_params[0];
- id<MTLComputePipelineState> pipeline = nil;
- if (n % 4 == 0 && op != GGML_UNARY_OP_RELU) {
- pipeline = op == GGML_UNARY_OP_GELU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_GELU_4].pipeline
- : ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_SILU_4].pipeline;
- n /= 4;
- } else {
- pipeline = op == GGML_UNARY_OP_GELU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_GELU].pipeline
- : op == GGML_UNARY_OP_SILU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_SILU].pipeline
- : ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_RELU].pipeline;
- }
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_OP_SQR:
- {
- GGML_ASSERT(ggml_is_contiguous(src0));
-
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
- const int64_t n = ggml_nelements(dst);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_OP_SUM_ROWS:
- {
- GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
-
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_OP_SOFT_MAX:
- {
- GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
-
- int nth = 32; // SIMD width
-
- id<MTLComputePipelineState> pipeline = nil;
-
- const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
-
- if (ne00%4 == 0) {
- while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
- nth *= 2;
- }
- if (use_f16) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline;
- } else {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
- }
- } else {
- while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
- nth *= 2;
- }
- if (use_f16) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline;
- } else {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline;
- }
- }
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_UNARY_OP_SWIGLU:
+ {
+ int64_t n = ggml_nelements(dst);
+ GGML_ASSERT(ne0 == src0->ne[0]/2);
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ uint32_t n_per_row = ne0;
+ uint32_t stride = src0->nb[1]/sizeof(float);
+
+ if (ne0 % 4 == 0) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU_4].pipeline;
+ n /= 4;
+ n_per_row /= 4;
+ stride /= 4;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
+ }
- float scale;
- float max_bias;
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&n_per_row length:sizeof(n_per_row) atIndex:2];
+ [encoder setBytes:&stride length:sizeof(stride) atIndex:3];
- memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
- memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ default:
+ {
+ GGML_METAL_LOG_WARN("%s: node %s, op = %8s not implemented\n", __func__, dst->name, ggml_op_name(dst->op));
+ GGML_ABORT("fatal error");
+ }
+ } break;
+ case GGML_OP_FUSED_MUL_UNARY:
+ {
+ int64_t n = ggml_nelements(dst);
+ enum ggml_unary_op op = (enum ggml_unary_op)dst->op_params[0];
+ id<MTLComputePipelineState> pipeline = nil;
+ if (n % 4 == 0 && op != GGML_UNARY_OP_RELU) {
+ pipeline = op == GGML_UNARY_OP_GELU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_GELU_4].pipeline
+ : ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_SILU_4].pipeline;
+ n /= 4;
+ } else {
+ pipeline = op == GGML_UNARY_OP_GELU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_GELU].pipeline
+ : op == GGML_UNARY_OP_SILU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_SILU].pipeline
+ : ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_RELU].pipeline;
+ }
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_SQR:
+ {
+ GGML_ASSERT(ggml_is_contiguous(src0));
- const int64_t nrows_x = ggml_nrows(src0);
- const int64_t nrows_y = src0->ne[1];
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline;
- const uint32_t n_head = nrows_x/nrows_y;
- const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+ const int64_t n = ggml_nelements(dst);
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- if (id_src1) {
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- } else {
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
- }
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
- [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
- [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7];
- [encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
- [encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
- [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
- [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_SOFT_CAP_MAX:
- {
- GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
-
- int nth = 32; // SIMD width
-
- id<MTLComputePipelineState> pipeline = nil;
-
- const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
-
- if (ne00%4 == 0) {
- while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
- nth *= 2;
- }
- if (use_f16) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16_4].pipeline;
- } else {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32_4].pipeline;
- }
- } else {
- while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
- nth *= 2;
- }
- if (use_f16) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16].pipeline;
- } else {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32].pipeline;
- }
- }
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_SUM_ROWS:
+ {
+ GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_SOFT_MAX:
+ {
+ GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
+
+ int nth = 32; // SIMD width
+
+ id<MTLComputePipelineState> pipeline = nil;
- float scale;
- float max_bias;
- float s_before;
- float s_after;
+ const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
- memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
- memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
- memcpy(&s_before, ((int32_t *) dst->op_params) + 2, sizeof(s_before));
- memcpy(&s_after, ((int32_t *) dst->op_params) + 3, sizeof(s_after));
+ if (ne00%4 == 0) {
+ while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
+ nth *= 2;
+ }
+ if (use_f16) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
+ }
+ } else {
+ while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
+ nth *= 2;
+ }
+ if (use_f16) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline;
+ }
+ }
- const int64_t nrows_x = ggml_nrows(src0);
- const int64_t nrows_y = src0->ne[1];
+ float scale;
+ float max_bias;
- const uint32_t n_head = nrows_x/nrows_y;
- const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+ memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
+ memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
- const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+ const int64_t nrows_x = ggml_nrows(src0);
+ const int64_t nrows_y = src0->ne[1];
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- if (id_src1) {
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- } else {
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
- }
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
- [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
- [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7];
- [encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
- [encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
- [encoder setBytes:&s_before length:sizeof(s_before) atIndex:10];
- [encoder setBytes:&s_after length:sizeof(s_after ) atIndex:11];
- [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:12];
- [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_DIAG_MASK_INF:
- {
- const int n_past = ((int32_t *)(dst->op_params))[0];
-
- id<MTLComputePipelineState> pipeline = nil;
-
- if (ne00%8 == 0) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline;
- } else {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
- }
+ const uint32_t n_head = nrows_x/nrows_y;
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
- [encoder setBytes:&n_past length:sizeof(int) atIndex:4];
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
- if (ne00%8 == 0) {
- [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- }
- else {
- [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- }
- } break;
- case GGML_OP_MUL_MAT:
- {
- GGML_ASSERT(ne00 == ne10);
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ if (id_src1) {
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ } else {
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ }
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
+ [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7];
+ [encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
+ [encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
+ [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_SOFT_CAP_MAX:
+ {
+ GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
+
+ int nth = 32; // SIMD width
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
- GGML_ASSERT(ne12 % ne02 == 0);
- GGML_ASSERT(ne13 % ne03 == 0);
+ if (ne00%4 == 0) {
+ while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
+ nth *= 2;
+ }
+ if (use_f16) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16_4].pipeline;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32_4].pipeline;
+ }
+ } else {
+ while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
+ nth *= 2;
+ }
+ if (use_f16) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16].pipeline;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32].pipeline;
+ }
+ }
- const uint r2 = ne12/ne02;
- const uint r3 = ne13/ne03;
+ float scale;
+ float max_bias;
+ float s_before;
+ float s_after;
- // find the break-even point where the matrix-matrix kernel becomes more efficient compared
- // to the matrix-vector kernel
- int ne11_mm_min = 1;
+ memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
+ memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
+ memcpy(&s_before, ((int32_t *) dst->op_params) + 2, sizeof(s_before));
+ memcpy(&s_after, ((int32_t *) dst->op_params) + 3, sizeof(s_after));
+
+ const int64_t nrows_x = ggml_nrows(src0);
+ const int64_t nrows_y = src0->ne[1];
+
+ const uint32_t n_head = nrows_x/nrows_y;
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ if (id_src1) {
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ } else {
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ }
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
+ [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7];
+ [encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
+ [encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
+ [encoder setBytes:&s_before length:sizeof(s_before) atIndex:10];
+ [encoder setBytes:&s_after length:sizeof(s_after ) atIndex:11];
+ [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:12];
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_DIAG_MASK_INF:
+ {
+ const int n_past = ((int32_t *)(dst->op_params))[0];
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ if (ne00%8 == 0) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
+ }
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+ [encoder setBytes:&n_past length:sizeof(int) atIndex:4];
+
+ if (ne00%8 == 0) {
+ [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ }
+ else {
+ [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ }
+ } break;
+ case GGML_OP_MUL_MAT:
+ {
+ GGML_ASSERT(ne00 == ne10);
+
+ GGML_ASSERT(ne12 % ne02 == 0);
+ GGML_ASSERT(ne13 % ne03 == 0);
+
+ const uint r2 = ne12/ne02;
+ const uint r3 = ne13/ne03;
+
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
+ // to the matrix-vector kernel
+ int ne11_mm_min = 4;
#if 0
- // the numbers below are measured on M2 Ultra for 7B and 13B models
- // these numbers do not translate to other devices or model sizes
- // TODO: need to find a better approach
+ // the numbers below are measured on M2 Ultra for 7B and 13B models
+ // these numbers do not translate to other devices or model sizes
+ // TODO: need to find a better approach
if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
switch (src0t) {
case GGML_TYPE_F16: ne11_mm_min = 2; break;
@@ -1976,11 +1943,11 @@ static enum ggml_status ggml_metal_graph_compute(
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
- !ggml_is_transposed(src0) &&
- !ggml_is_transposed(src1) &&
- src1t == GGML_TYPE_F32 &&
- ne00 % 32 == 0 && ne00 >= 64 &&
- (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
+ !ggml_is_transposed(src0) &&
+ !ggml_is_transposed(src1) &&
+ src1t == GGML_TYPE_F32 &&
+ ne00 % 32 == 0 && ne00 >= 64 &&
+ (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
// some Metal matrix data types require aligned pointers
@@ -2305,9 +2272,9 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
- src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
- src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S||
- src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_Q6_0) {
+ src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
+ src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S||
+ src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_Q6_0) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_IQ2_KS || src0t == GGML_TYPE_IQ2_K || src0t == GGML_TYPE_IQ3_K) {
@@ -2326,8 +2293,8 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS || src0t == GGML_TYPE_IQ4_K ||
- src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ6_K || src0t == GGML_TYPE_IQ4_KS||
- src0t == GGML_TYPE_IQ4_KSS) {
+ src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ6_K || src0t == GGML_TYPE_IQ4_KS||
+ src0t == GGML_TYPE_IQ4_KSS) {
const int mem_size = src0t == GGML_TYPE_IQ6_K ? 128*sizeof(float) : GGML_TYPE_IQ5_K ? 64*sizeof(float) : 32*sizeof(float);
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -2348,1179 +2315,1254 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
}
- } break;
- case GGML_OP_MUL_MAT_ID:
- {
- const int n_as = src0->ne[2];
+ } break;
+ case GGML_OP_MUL_MAT_ID:
+ {
+ const int n_as = src0->ne[2];
+
+ // src2 = ids
+ const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
+
+ GGML_ASSERT(src2t == GGML_TYPE_I32);
+
+ GGML_ASSERT(!ggml_is_transposed(src0));
+ GGML_ASSERT(!ggml_is_transposed(src1));
+
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
+ // to the matrix-vector kernel
+ // ne20 = n_used_experts
+ // ne21 = n_rows
+ const int dst_rows = ne20*ne21;
+ const int dst_rows_min = n_as;
+ //const int dst_rows_max = (ctx->device.maxThreadgroupMemoryLength/2 - 8192)/4;
+ const int dst_rows_max = (ctx->device.maxThreadgroupMemoryLength - 8192)/4;
+
+ // max size of the rowids array in the kernel shared buffer
+ //GGML_ASSERT(dst_rows <= dst_rows_max);
+
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
+ // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
+ // !!!
+ // TODO: for now, always use mat-vec kernels until we figure out how to improve the
+ // indirect matrix multiplication
+ // !!!
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+ ne00 % 32 == 0 && ne00 >= 64 &&
+ dst_rows > dst_rows_min &&
+ dst_rows <= dst_rows_max) {
+
+ // some Metal matrix data types require aligned pointers
+ // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
+ switch (src0->type) {
+ case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
+ case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
+ default: break;
+ }
- // src2 = ids
- const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
+ id<MTLComputePipelineState> pipeline = nil;
+
+ switch (src0->type) {
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
+ case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32 ].pipeline; break;
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break;
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break;
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break;
+ case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_0_F32 ].pipeline; break;
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break;
+ case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break;
+ case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break;
+ case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break;
+ case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break;
+ case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break;
+ case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
+ case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
+ case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
+ case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break;
+ case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break;
+ case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
+ case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break;
+ case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32 ].pipeline; break;
+ case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32 ].pipeline; break;
+ case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
+ case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
+ case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KS_F32 ].pipeline; break;
+ case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KSS_F32].pipeline; break;
+ case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_K_F32 ].pipeline; break;
+ case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_KS_F32 ].pipeline; break;
+ case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_K_F32 ].pipeline; break;
+ case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_K_F32 ].pipeline; break;
+ case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ5_K_F32 ].pipeline; break;
+ case GGML_TYPE_IQ6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ6_K_F32 ].pipeline; break;
+ default: GGML_ABORT("MUL_MAT_ID not implemented");
+ }
- GGML_ASSERT(src2t == GGML_TYPE_I32);
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+ [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
+
+ [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(1, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+
+ } else {
+ int nth0 = 32;
+ int nth1 = 1;
+ int nrows = 1;
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ // use custom matrix x vector kernel
+ switch (src0t) {
+ case GGML_TYPE_F32:
+ {
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline;
+ } break;
+ case GGML_TYPE_F16:
+ {
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+ nth0 = 32;
+ nth1 = 1;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
+ } break;
+ case GGML_TYPE_BF16:
+ {
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+ nth0 = 32;
+ nth1 = 1;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q4_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q4_1:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q5_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q5_1:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q6_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_0_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q8_0:
+ {
+ nth0 = 32;
+ nth1 = 2;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q2_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q3_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q4_K:
+ {
+ nth0 = 4; //1;
+ nth1 = 8; //32;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q5_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q6_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ2_XXS:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ2_XS:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ3_XXS:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ3_S:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ2_S:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ1_S:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ1_M:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ1_BN:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_BN_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ2_BN:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_BN_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ4_NL:
+ {
+ nth0 = 32;
+ nth1 = 2;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ4_XS:
+ {
+ nth0 = 32;
+ nth1 = 2;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ4_KS:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KS_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ4_KSS:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KSS_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ2_K:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_K_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ2_KS:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_KS_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ3_K:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_K_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ4_K:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_K_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ5_K:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ5_K_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ6_K:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ6_K_F32].pipeline;
+ } break;
+ default:
+ {
+ GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
+ GGML_ABORT("not implemented");
+ }
+ };
- GGML_ASSERT(!ggml_is_transposed(src0));
- GGML_ASSERT(!ggml_is_transposed(src1));
+ if (ggml_is_quantized(src0t)) {
+ GGML_ASSERT(ne00 >= nth0*nth1);
+ }
- GGML_ASSERT(src1t == GGML_TYPE_F32);
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+ [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22];
+
+ const int64_t _ne1 = 1;
+ const int tgz = dst_rows;
+
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
+ src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
+ src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_Q6_0 ||
+ src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src0t == GGML_TYPE_IQ2_KS || src0t == GGML_TYPE_IQ2_K || src0t == GGML_TYPE_IQ3_K) {
+ const int mem_size = src0t == GGML_TYPE_IQ2_KS ? 64*sizeof(float) : src0t == GGML_TYPE_IQ3_K ? 32*sizeof(float) : 16*sizeof(float);
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
+ const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
+ const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS || src0t == GGML_TYPE_IQ4_K ||
+ src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ6_K || src0t == GGML_TYPE_IQ4_KS||
+ src0t == GGML_TYPE_IQ4_KSS) {
+ const int mem_size = src0t == GGML_TYPE_IQ6_K ? 128*sizeof(float) : GGML_TYPE_IQ5_K ? 64*sizeof(float) : 32*sizeof(float);
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src0t == GGML_TYPE_Q4_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src0t == GGML_TYPE_Q3_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src0t == GGML_TYPE_Q5_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src0t == GGML_TYPE_Q6_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ } else {
+ const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ }
+ } break;
+ case GGML_OP_GET_ROWS:
+ {
+ id<MTLComputePipelineState> pipeline = nil;
+
+ switch (src0->type) {
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break;
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break;
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break;
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break;
+ case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_0 ].pipeline; break;
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break;
+ case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break;
+ case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break;
+ case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break;
+ case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break;
+ case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break;
+ case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
+ case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
+ case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
+ case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break;
+ case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break;
+ case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break;
+ case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break;
+ case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN ].pipeline; break;
+ case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN ].pipeline; break;
+ case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
+ case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break;
+ case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KS ].pipeline; break;
+ case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KSS].pipeline; break;
+ case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_K ].pipeline; break;
+ case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_KS ].pipeline; break;
+ case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_K ].pipeline; break;
+ case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_K ].pipeline; break;
+ case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ5_K ].pipeline; break;
+ case GGML_TYPE_IQ6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ6_K ].pipeline; break;
+ case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
+ default: GGML_ABORT("not implemented");
+ }
- // find the break-even point where the matrix-matrix kernel becomes more efficient compared
- // to the matrix-vector kernel
- // ne20 = n_used_experts
- // ne21 = n_rows
- const int dst_rows = ne20*ne21;
- const int dst_rows_min = n_as;
- const int dst_rows_max = (ctx->device.maxThreadgroupMemoryLength - 32 - 8192)/4;
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
+ [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
+ [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
+ [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
+ } break;
+ case GGML_OP_RMS_NORM:
+ {
+ GGML_ASSERT(ne00 % 4 == 0);
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
- // max size of the rowids array in the kernel shared buffer
- GGML_ASSERT(dst_rows <= dst_rows_max);
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
- // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
- // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
- // !!!
- // TODO: for now, always use mat-vec kernels until we figure out how to improve the
- // indirect matrix multiplication
- // !!!
- if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
- ne00 % 32 == 0 && ne00 >= 64 &&
- dst_rows > dst_rows_min) {
+ int nth = 32; // SIMD width
- // some Metal matrix data types require aligned pointers
- // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
- switch (src0->type) {
- case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
- case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
- default: break;
- }
+ while (nth < ne00/4 && nth < 1024) {
+ nth *= 2;
+ }
- id<MTLComputePipelineState> pipeline = nil;
+ nth = MIN(nth, ne00/4);
- switch (src0->type) {
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
- case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32 ].pipeline; break;
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break;
- case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break;
- case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break;
- case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_0_F32 ].pipeline; break;
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break;
- case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break;
- case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break;
- case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break;
- case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break;
- case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break;
- case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
- case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
- case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
- case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break;
- case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break;
- case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
- case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break;
- case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32 ].pipeline; break;
- case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32 ].pipeline; break;
- case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
- case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
- case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KS_F32 ].pipeline; break;
- case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KSS_F32].pipeline; break;
- case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_K_F32 ].pipeline; break;
- case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_KS_F32 ].pipeline; break;
- case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_K_F32 ].pipeline; break;
- case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_K_F32 ].pipeline; break;
- case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ5_K_F32 ].pipeline; break;
- case GGML_TYPE_IQ6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ6_K_F32 ].pipeline; break;
- default: GGML_ABORT("MUL_MAT_ID not implemented");
- }
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
- [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
- [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
-
- [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
-
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
- } else {
- int nth0 = 32;
- int nth1 = 1;
- int nrows = 1;
- //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
- id<MTLComputePipelineState> pipeline = nil;
+ const int64_t nrows = ggml_nrows(src0);
- // use custom matrix x vector kernel
- switch (src0t) {
- case GGML_TYPE_F32:
- {
- GGML_ASSERT(src1t == GGML_TYPE_F32);
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline;
- } break;
- case GGML_TYPE_F16:
- {
- GGML_ASSERT(src1t == GGML_TYPE_F32);
- nth0 = 32;
- nth1 = 1;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
- } break;
- case GGML_TYPE_BF16:
- {
- GGML_ASSERT(src1t == GGML_TYPE_F32);
- nth0 = 32;
- nth1 = 1;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline;
- } break;
- case GGML_TYPE_Q4_0:
- {
- nth0 = 8;
- nth1 = 8;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline;
- } break;
- case GGML_TYPE_Q4_1:
- {
- nth0 = 8;
- nth1 = 8;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline;
- } break;
- case GGML_TYPE_Q5_0:
- {
- nth0 = 8;
- nth1 = 8;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline;
- } break;
- case GGML_TYPE_Q5_1:
- {
- nth0 = 8;
- nth1 = 8;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline;
- } break;
- case GGML_TYPE_Q6_0:
- {
- nth0 = 8;
- nth1 = 8;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_0_F32].pipeline;
- } break;
- case GGML_TYPE_Q8_0:
- {
- nth0 = 8;
- nth1 = 8;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
- } break;
- case GGML_TYPE_Q2_K:
- {
- nth0 = 2;
- nth1 = 32;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline;
- } break;
- case GGML_TYPE_Q3_K:
- {
- nth0 = 2;
- nth1 = 32;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline;
- } break;
- case GGML_TYPE_Q4_K:
- {
- nth0 = 4; //1;
- nth1 = 8; //32;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline;
- } break;
- case GGML_TYPE_Q5_K:
- {
- nth0 = 2;
- nth1 = 32;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline;
- } break;
- case GGML_TYPE_Q6_K:
- {
- nth0 = 2;
- nth1 = 32;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline;
- } break;
- case GGML_TYPE_IQ2_XXS:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline;
- } break;
- case GGML_TYPE_IQ2_XS:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
- } break;
- case GGML_TYPE_IQ3_XXS:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
- } break;
- case GGML_TYPE_IQ3_S:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline;
- } break;
- case GGML_TYPE_IQ2_S:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline;
- } break;
- case GGML_TYPE_IQ1_S:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
- } break;
- case GGML_TYPE_IQ1_M:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline;
- } break;
- case GGML_TYPE_IQ1_BN:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_BN_F32].pipeline;
- } break;
- case GGML_TYPE_IQ2_BN:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_BN_F32].pipeline;
- } break;
- case GGML_TYPE_IQ4_NL:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
- } break;
- case GGML_TYPE_IQ4_XS:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
- } break;
- case GGML_TYPE_IQ4_KS:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KS_F32].pipeline;
- } break;
- case GGML_TYPE_IQ4_KSS:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KSS_F32].pipeline;
- } break;
- case GGML_TYPE_IQ2_K:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_K_F32].pipeline;
- } break;
- case GGML_TYPE_IQ2_KS:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_KS_F32].pipeline;
- } break;
- case GGML_TYPE_IQ3_K:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_K_F32].pipeline;
- } break;
- case GGML_TYPE_IQ4_K:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_K_F32].pipeline;
- } break;
- case GGML_TYPE_IQ5_K:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ5_K_F32].pipeline;
- } break;
- case GGML_TYPE_IQ6_K:
- {
- nth0 = 4;
- nth1 = 16;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ6_K_F32].pipeline;
- } break;
- default:
- {
- GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
- GGML_ABORT("not implemented");
- }
- };
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_FUSED_RMS_NORM:
+ {
+ GGML_ASSERT(ne00 % 4 == 0);
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
+ GGML_ASSERT(src1->ne[0] == src0->ne[0]);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT(ggml_nrows(src1) == 1);
- if (ggml_is_quantized(src0t)) {
- GGML_ASSERT(ne00 >= nth0*nth1);
- }
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
- [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
- [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9];
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12];
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13];
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14];
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15];
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21];
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22];
-
- const int64_t _ne1 = 1;
- const int tgz = dst_rows;
+ int nth = 32; // SIMD width
- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
- src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
- src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_Q6_0 ||
- src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == GGML_TYPE_IQ2_KS || src0t == GGML_TYPE_IQ2_K || src0t == GGML_TYPE_IQ3_K) {
- const int mem_size = src0t == GGML_TYPE_IQ2_KS ? 64*sizeof(float) : src0t == GGML_TYPE_IQ3_K ? 32*sizeof(float) : 16*sizeof(float);
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
- const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
- const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS || src0t == GGML_TYPE_IQ4_K ||
- src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ6_K || src0t == GGML_TYPE_IQ4_KS||
- src0t == GGML_TYPE_IQ4_KSS) {
- const int mem_size = src0t == GGML_TYPE_IQ6_K ? 128*sizeof(float) : GGML_TYPE_IQ5_K ? 64*sizeof(float) : 32*sizeof(float);
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == GGML_TYPE_Q4_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == GGML_TYPE_Q3_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == GGML_TYPE_Q5_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- else if (src0t == GGML_TYPE_Q6_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- } else {
- const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- }
- } break;
- case GGML_OP_GET_ROWS:
- {
- id<MTLComputePipelineState> pipeline = nil;
-
- switch (src0->type) {
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break;
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break;
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break;
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break;
- case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
- case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break;
- case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_0 ].pipeline; break;
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break;
- case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break;
- case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break;
- case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break;
- case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break;
- case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break;
- case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
- case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
- case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
- case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break;
- case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break;
- case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break;
- case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break;
- case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN ].pipeline; break;
- case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN ].pipeline; break;
- case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
- case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break;
- case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KS ].pipeline; break;
- case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KSS].pipeline; break;
- case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_K ].pipeline; break;
- case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_KS ].pipeline; break;
- case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_K ].pipeline; break;
- case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_K ].pipeline; break;
- case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ5_K ].pipeline; break;
- case GGML_TYPE_IQ6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ6_K ].pipeline; break;
- case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
- default: GGML_ABORT("not implemented");
- }
+ while (nth < ne00/4 && nth < 1024) {
+ nth *= 2;
+ }
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
- [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
- [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
- [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
- } break;
- case GGML_OP_RMS_NORM:
- {
- GGML_ASSERT(ne00 % 4 == 0);
- GGML_ASSERT(ggml_is_contiguous_1(src0));
-
- float eps;
- memcpy(&eps, dst->op_params, sizeof(float));
-
- int nth = 32; // SIMD width
-
- while (nth < ne00/4 && nth < 1024) {
- nth *= 2;
- }
+ nth = MIN(nth, ne00/4);
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FUSED_RMS_NORM].pipeline;
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
- [encoder setBytes:&eps length:sizeof( float) atIndex:4];
- [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
+ [encoder setBytes:&eps length:sizeof( float) atIndex:5];
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
- const int64_t nrows = ggml_nrows(src0);
+ const int64_t nrows = ggml_nrows(src0);
- [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_FUSED_RMS_NORM:
- {
- GGML_ASSERT(ne00 % 4 == 0);
- GGML_ASSERT(ggml_is_contiguous_1(src0));
- GGML_ASSERT(src1->ne[0] == src0->ne[0]);
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
- GGML_ASSERT(ggml_nrows(src1) == 1);
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_GROUP_NORM:
+ {
+ GGML_ASSERT(ne00 % 4 == 0);
+ GGML_ASSERT(ggml_is_contiguous(src0));
- float eps;
- memcpy(&eps, dst->op_params, sizeof(float));
+ float eps;
+ memcpy(&eps, dst->op_params + 1, sizeof(float));
- int nth = 32; // SIMD width
+ const int32_t n_groups = ((int32_t *) dst->op_params)[0];
- while (nth < ne00/4 && nth < 1024) {
- nth *= 2;
- }
+ int nth = 32; // SIMD width
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FUSED_RMS_NORM].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
- [encoder setBytes:&eps length:sizeof( float) atIndex:5];
- [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
-
- const int64_t nrows = ggml_nrows(src0);
-
- [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_GROUP_NORM:
- {
- GGML_ASSERT(ne00 % 4 == 0);
- GGML_ASSERT(ggml_is_contiguous(src0));
-
- float eps;
- memcpy(&eps, dst->op_params + 1, sizeof(float));
-
- const int32_t n_groups = ((int32_t *) dst->op_params)[0];
-
- int nth = 32; // SIMD width
-
- //while (nth < ne00/4 && nth < 1024) {
- // nth *= 2;
- //}
-
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
- [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
- [encoder setBytes:&eps length:sizeof( float) atIndex:9];
- [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
-
- [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_NORM:
- {
- GGML_ASSERT(ggml_is_contiguous_1(src0));
-
- float eps;
- memcpy(&eps, dst->op_params, sizeof(float));
-
- const int nth = MIN(256, ne00);
-
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
- [encoder setBytes:&eps length:sizeof( float) atIndex:4];
- [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];
-
- const int64_t nrows = ggml_nrows(src0);
-
- [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_ROPE:
- {
- GGML_ASSERT(ne10 == ne02);
-
- const int nth = MIN(1024, ne00);
-
- const int n_past = ((int32_t *) dst->op_params)[0];
- const int n_dims = ((int32_t *) dst->op_params)[1];
- const int mode = ((int32_t *) dst->op_params)[2];
- // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
- const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
-
- float freq_base;
- float freq_scale;
- float ext_factor;
- float attn_factor;
- float beta_fast;
- float beta_slow;
-
- memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
- memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
- memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
- memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
- memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
- memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
-
- const bool is_neox = mode & 2;
-
- id<MTLComputePipelineState> pipeline = nil;
-
- if (!is_neox) {
- switch (src0->type) {
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
- default: GGML_ABORT("fatal error");
- };
- } else {
- switch (src0->type) {
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
- default: GGML_ABORT("fatal error");
- };
- }
+ //while (nth < ne00/4 && nth < 1024) {
+ // nth *= 2;
+ //}
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- if (id_src2 != nil) {
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
- } else {
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
- }
- [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4];
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9];
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10];
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11];
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12];
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13];
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14];
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15];
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16];
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17];
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18];
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
- [encoder setBytes:&n_past length:sizeof( int) atIndex:20];
- [encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
- [encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22];
- [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
- [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
- [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
- [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
- [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
- [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_IM2COL:
- {
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
-
- const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
- const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
- const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
- const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
- const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
- const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
-
- const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
-
- const int32_t N = src1->ne[is_2D ? 3 : 2];
- const int32_t IC = src1->ne[is_2D ? 2 : 1];
- const int32_t IH = is_2D ? src1->ne[1] : 1;
- const int32_t IW = src1->ne[0];
-
- const int32_t KH = is_2D ? src0->ne[1] : 1;
- const int32_t KW = src0->ne[0];
-
- const int32_t OH = is_2D ? dst->ne[2] : 1;
- const int32_t OW = dst->ne[1];
-
- const int32_t CHW = IC * KH * KW;
-
- const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
- const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
-
- id<MTLComputePipelineState> pipeline = nil;
-
- switch (dst->type) {
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break;
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
- default: GGML_ABORT("fatal error");
- };
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
- [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
- [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
- [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
- [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
- [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
- [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
- [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
- [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
- [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
- [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
-
- [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
- } break;
- case GGML_OP_UPSCALE:
- {
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
-
- const float sf0 = (float)ne0/src0->ne[0];
- const float sf1 = (float)ne1/src0->ne[1];
- const float sf2 = (float)ne2/src0->ne[2];
- const float sf3 = (float)ne3/src0->ne[3];
-
- const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
- [encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18];
- [encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19];
- [encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20];
- [encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21];
-
- const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_PAD:
- {
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
-
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
-
- const int nth = MIN(1024, ne0);
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_ARANGE:
- {
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
-
- float start;
- float step;
-
- memcpy(&start, ((int32_t *) dst->op_params) + 0, sizeof(float));
- memcpy(&step, ((int32_t *) dst->op_params) + 2, sizeof(float));
-
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:0];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1];
- [encoder setBytes:&start length:sizeof(start) atIndex:2];
- [encoder setBytes:&step length:sizeof(step) atIndex:3];
-
- const int nth = MIN(1024, ne0);
-
- [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_TIMESTEP_EMBEDDING:
- {
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
-
- const int dim = dst->op_params[0];
- const int max_period = dst->op_params[1];
-
- const int half = dim / 2;
-
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2];
- [encoder setBytes:&dim length:sizeof(dim) atIndex:3];
- [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4];
-
- const int nth = MIN(1024, half);
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_ARGSORT:
- {
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_I32);
-
- const int nrows = ggml_nrows(src0);
-
- enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
-
- // bitonic sort requires the number of elements to be power of 2
- int64_t ne00_padded = 1;
- while (ne00_padded < ne00) {
- ne00_padded *= 2;
- }
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
- // Metal kernels require the buffer size to be multiple of 16 bytes
- // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
- const int mem_size = GGML_PAD(ne00_padded*sizeof(int32_t), 16);
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
+ [encoder setBytes:&eps length:sizeof( float) atIndex:9];
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
- id<MTLComputePipelineState> pipeline = nil;
+ [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_NORM:
+ {
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
- switch (order) {
- case GGML_SORT_ORDER_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break;
- case GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
- default: GGML_ABORT("fatal error");
- };
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
- [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3];
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+ const int nth = MIN(256, ne00);
- [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)];
- } break;
- case GGML_OP_LEAKY_RELU:
- {
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline;
- float slope;
- memcpy(&slope, dst->op_params, sizeof(float));
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
+ [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
+ const int64_t nrows = ggml_nrows(src0);
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_ROPE:
+ {
+ GGML_ASSERT(ne10 == ne02);
+
+ const int nth = MIN(1024, ne00);
+
+ const int n_past = ((int32_t *) dst->op_params)[0];
+ const int n_dims = ((int32_t *) dst->op_params)[1];
+ const int mode = ((int32_t *) dst->op_params)[2];
+ // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
+ const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+
+ float freq_base;
+ float freq_scale;
+ float ext_factor;
+ float attn_factor;
+ float beta_fast;
+ float beta_slow;
+
+ memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
+ memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
+ memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
+ memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
+ memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
+
+ const bool is_neox = mode & 2;
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ if (!is_neox) {
+ switch (src0->type) {
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
+ default: GGML_ABORT("fatal error");
+ };
+ } else {
+ switch (src0->type) {
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
+ default: GGML_ABORT("fatal error");
+ };
+ }
- const int64_t n = ggml_nelements(dst);
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ if (id_src2 != nil) {
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+ } else {
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
+ }
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18];
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
+ [encoder setBytes:&n_past length:sizeof( int) atIndex:20];
+ [encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
+ [encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22];
+ [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
+ [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
+ [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
+ [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
+ [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
+ [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_IM2COL:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_OP_FLASH_ATTN_EXT:
- {
- GGML_ASSERT(ne00 % 4 == 0);
- GGML_ASSERT(ne11 % 32 == 0);
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+ const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
- GGML_ASSERT(ggml_are_same_shape (src1, src2));
+ const int32_t N = src1->ne[is_2D ? 3 : 2];
+ const int32_t IC = src1->ne[is_2D ? 2 : 1];
+ const int32_t IH = is_2D ? src1->ne[1] : 1;
+ const int32_t IW = src1->ne[0];
- struct ggml_tensor * src3 = gf->nodes[i]->src[3];
+ const int32_t KH = is_2D ? src0->ne[1] : 1;
+ const int32_t KW = src0->ne[0];
- size_t offs_src3 = 0;
+ const int32_t OH = is_2D ? dst->ne[2] : 1;
+ const int32_t OW = dst->ne[1];
- id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
+ const int32_t CHW = IC * KH * KW;
- GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16);
- GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
- "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
+ const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
+ const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
- const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
- //const int64_t ne31 = src3 ? src3->ne[1] : 0;
- const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
- const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
+ id<MTLComputePipelineState> pipeline = nil;
- const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30);
- const uint64_t nb31 = src3 ? src3->nb[1] : 0;
- const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32);
- const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33);
+ switch (dst->type) {
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
+ default: GGML_ABORT("fatal error");
+ };
- const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
+ [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
+ [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
+ [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
+ [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
+ [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
+ [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
+ [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
+ [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
+ [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
+ [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
+ } break;
+ case GGML_OP_UPSCALE:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+ const float sf0 = (float)ne0/src0->ne[0];
+ const float sf1 = (float)ne1/src0->ne[1];
+ const float sf2 = (float)ne2/src0->ne[2];
+ const float sf3 = (float)ne3/src0->ne[3];
+
+ const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
+ [encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18];
+ [encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19];
+ [encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20];
+ [encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21];
+
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_PAD:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
+
+ const int nth = MIN(1024, ne0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_ARANGE:
+ {
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
- float scale;
- float max_bias;
- float softcap;
+ float start;
+ float step;
- memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
- memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
- memcpy(&softcap, ((int32_t *) dst->op_params) + 2, sizeof(softcap));
- if (softcap != 0.0f) {
- scale /= softcap;
- }
+ memcpy(&start, ((int32_t *) dst->op_params) + 0, sizeof(float));
+ memcpy(&step, ((int32_t *) dst->op_params) + 2, sizeof(float));
- const uint32_t n_head = src0->ne[2];
- const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
- const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:0];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1];
+ [encoder setBytes:&start length:sizeof(start) atIndex:2];
+ [encoder setBytes:&step length:sizeof(step) atIndex:3];
- id<MTLComputePipelineState> pipeline = nil;
+ const int nth = MIN(1024, ne0);
- bool use_vec_kernel = false;
+ [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_TIMESTEP_EMBEDDING:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
- if (ne01 >= 4 || (ne00%128 != 0)) {
- switch (ne00) {
- case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
- case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
- case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
- case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
- case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
- //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
- default:
- {
- GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
- GGML_METAL_LOG_ERROR("add template specialization for this size\n");
- GGML_ABORT("add template specialization for this size");
- }
- }
- } else {
- use_vec_kernel = true;
+ const int dim = dst->op_params[0];
+ const int max_period = dst->op_params[1];
- switch (ne00) {
- case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
- //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
- default:
- {
- GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
- GGML_METAL_LOG_ERROR("add template specialization for this size\n");
- GGML_ABORT("add template specialization for this size");
- }
- }
- }
+ const int half = dim / 2;
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
- if (id_src3) {
- [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
- } else {
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
- }
- [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
- [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
- [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
- [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
- [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
- [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
- [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
- [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
- [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
- [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
- [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
- [encoder setBytes:&scale length:sizeof( float) atIndex:23];
- [encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
- [encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
- [encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
- [encoder setBytes:&softcap length:sizeof(softcap) atIndex:27];
- [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:28];
-
- if (!use_vec_kernel) {
- // half8x8 kernel
- const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
- const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
-
- GGML_ASSERT(nqptg <= 32);
- GGML_ASSERT(nqptg % 8 == 0);
- GGML_ASSERT(ncpsg % 32 == 0);
-
- int64_t nsgmax = 2;
-
- while (true) {
- const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2);
- if (smem > ctx->device.maxThreadgroupMemoryLength) {
- break;
- }
- nsgmax *= 2;
- }
- nsgmax /= 2;
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2];
+ [encoder setBytes:&dim length:sizeof(dim) atIndex:3];
+ [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4];
+
+ const int nth = MIN(1024, half);
- // simdgroups per threadgroup (a.k.a. warps)
- const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
+ [encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_ARGSORT:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_I32);
- const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);
+ const int nrows = ggml_nrows(src0);
- //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
- GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
+ enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
- [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
+ // bitonic sort requires the number of elements to be power of 2
+ int64_t ne00_padded = 1;
+ while (ne00_padded < ne00) {
+ ne00_padded *= 2;
+ }
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
- } else {
- // half1x4 kernel
- const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
- const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
+ // Metal kernels require the buffer size to be multiple of 16 bytes
+ // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
+ const int mem_size = GGML_PAD(ne00_padded*sizeof(int32_t), 16);
- GGML_ASSERT(nqptg <= 32);
- GGML_ASSERT(nqptg % 1 == 0);
- GGML_ASSERT(ncpsg % 32 == 0);
+ id<MTLComputePipelineState> pipeline = nil;
- // simdgroups per threadgroup (a.k.a. warps)
- const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
+ switch (order) {
+ case GGML_SORT_ORDER_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break;
+ case GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
+ default: GGML_ABORT("fatal error");
+ };
- int64_t nsg = 1;
- while (nsg <= nsgt) {
- nsg *= 2;
- }
- nsg /= 2;
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3];
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
- const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
+ [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)];
+ } break;
+ case GGML_OP_LEAKY_RELU:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
- //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
- GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
- [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
+ float slope;
+ memcpy(&slope, dst->op_params, sizeof(float));
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
- }
- } break;
- case GGML_OP_DUP:
- case GGML_OP_CPY:
- case GGML_OP_CONT:
- {
- GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
-
- int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
-
- id<MTLComputePipelineState> pipeline = nil;
-
- switch (src0t) {
- case GGML_TYPE_F32:
- {
- GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
-
- switch (dstt) {
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
- case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
- case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
- case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q6_0].pipeline; break;
- case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break;
- default: GGML_ABORT("not implemented");
- };
- } break;
- case GGML_TYPE_F16:
- {
- switch (dstt) {
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
- default: GGML_ABORT("not implemented");
- };
- } break;
- default: GGML_ABORT("not implemented");
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_FLASH_ATTN_EXT:
+ {
+ GGML_ASSERT(ne00 % 4 == 0);
+ GGML_ASSERT(ne11 % 32 == 0);
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+ GGML_ASSERT(ggml_are_same_shape (src1, src2));
+
+ struct ggml_tensor * src3 = node->src[3];
+
+ size_t offs_src3 = 0;
+
+ id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
+
+ GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16);
+ GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
+ "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
+
+ const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
+ //const int64_t ne31 = src3 ? src3->ne[1] : 0;
+ const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
+ const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
+
+ const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30);
+ const uint64_t nb31 = src3 ? src3->nb[1] : 0;
+ const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32);
+ const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33);
+
+ const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
+
+ float scale;
+ float max_bias;
+ float softcap;
+
+ memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
+ memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
+ memcpy(&softcap, ((int32_t *) dst->op_params) + 2, sizeof(softcap));
+ if (softcap != 0.0f) {
+ scale /= softcap;
+ }
+
+ const uint32_t n_head = src0->ne[2];
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ bool use_vec_kernel = false;
+
+ if (ne01 >= 4 || (ne00%128 != 0)) {
+ switch (ne00) {
+ case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
+ case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
+ case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
+ case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
+ case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
+ //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
+ default:
+ {
+ GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
+ GGML_METAL_LOG_ERROR("add template specialization for this size\n");
+ GGML_ABORT("add template specialization for this size");
+ }
+ }
+ } else {
+ use_vec_kernel = true;
+
+ switch (ne00) {
+ case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
+ //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+ default:
+ {
+ GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
+ GGML_METAL_LOG_ERROR("add template specialization for this size\n");
+ GGML_ABORT("add template specialization for this size");
+ }
+ }
+ }
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+ if (id_src3) {
+ [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
+ } else {
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
+ }
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
+ [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
+ [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
+ [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
+ [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
+ [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
+ [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
+ [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
+ [encoder setBytes:&scale length:sizeof( float) atIndex:23];
+ [encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
+ [encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
+ [encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
+ [encoder setBytes:&softcap length:sizeof(softcap) atIndex:27];
+ [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:28];
+
+ if (!use_vec_kernel) {
+ // half8x8 kernel
+ const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
+ const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
+
+ GGML_ASSERT(nqptg <= 32);
+ GGML_ASSERT(nqptg % 8 == 0);
+ GGML_ASSERT(ncpsg % 32 == 0);
+
+ int64_t nsgmax = 2;
+
+ while (true) {
+ const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2);
+ if (smem > ctx->device.maxThreadgroupMemoryLength) {
+ break;
}
+ nsgmax *= 2;
+ }
+ nsgmax /= 2;
- [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- default:
- {
- GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
- GGML_ABORT("fatal error");
+ // simdgroups per threadgroup (a.k.a. warps)
+ const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
+
+ const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);
+
+ //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
+ GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
+
+ [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
+
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+ } else {
+ // half1x4 kernel
+ const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
+ const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
+
+ GGML_ASSERT(nqptg <= 32);
+ GGML_ASSERT(nqptg % 1 == 0);
+ GGML_ASSERT(ncpsg % 32 == 0);
+
+ // simdgroups per threadgroup (a.k.a. warps)
+ const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
+
+ int64_t nsg = 1;
+ while (nsg <= nsgt) {
+ nsg *= 2;
}
- }
+ nsg /= 2;
- if (should_capture) {
- [encoder popDebugGroup];
- }
- }
+ const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
- [encoder endEncoding];
+ //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
+ GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
+ [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
- if (cb_idx < 2 || ctx->abort_callback == NULL) {
- [command_buffer commit];
- }
- });
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+ }
+ } break;
+ case GGML_OP_DUP:
+ case GGML_OP_CPY:
+ case GGML_OP_CONT:
+ {
+ GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
+
+ int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
- // Wait for completion and check status of each command buffer
- // needed to detect if the device ran out-of-memory for example (#1881)
+ id<MTLComputePipelineState> pipeline = nil;
- for (int i = 0; i < n_cb; ++i) {
- id<MTLCommandBuffer> command_buffer = command_buffers[i];
- [command_buffer waitUntilCompleted];
+ switch (src0t) {
+ case GGML_TYPE_F32:
+ {
+ GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
+
+ switch (dstt) {
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
+ case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q6_0].pipeline; break;
+ case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break;
+ default: GGML_ABORT("not implemented");
+ };
+ } break;
+ case GGML_TYPE_F16:
+ {
+ switch (dstt) {
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
+ default: GGML_ABORT("not implemented");
+ };
+ } break;
+ default: GGML_ABORT("not implemented");
+ }
- MTLCommandBufferStatus status = [command_buffer status];
- if (status != MTLCommandBufferStatusCompleted) {
- GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
- if (status == MTLCommandBufferStatusError) {
- NSString * error_code = [command_buffer error].localizedDescription;
- GGML_METAL_LOG_INFO("error: %s\n", [error_code UTF8String]);
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ default:
+ {
+ GGML_METAL_LOG_ERROR("%s: error: node %s, op = %8s not implemented\n", __func__, dst->name, ggml_op_name(dst->op));
+ GGML_ABORT("fatal error");
}
+ }
+
+}
+
+static enum ggml_status ggml_metal_graph_compute(
+ struct ggml_backend_metal_context * ctx,
+ struct ggml_cgraph * gf) {
- return GGML_STATUS_FAILED;
+ // number of nodes encoded by the main thread (empirically determined)
+ const int n_main = 128;
+
+ // number of threads in addition to the main thread
+ const int n_cb = ctx->n_cb;
+
+ @autoreleasepool {
+ ctx->gf = gf;
+
+ ctx->n_nodes_0 = MIN(n_main, gf->n_nodes);
+ ctx->n_nodes_1 = gf->n_nodes - ctx->n_nodes_0;
+
+ ctx->n_nodes_per_cb = (ctx->n_nodes_1 + ctx->n_cb - 1) / ctx->n_cb;
+
+ const bool should_capture = ctx->capture_next_compute;
+ if (should_capture) {
+ ctx->capture_next_compute = false;
+
+ if (!ctx->capture_started) {
+ // create capture scope
+ ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx->device]; //ctx_dev->mtl_device];
+
+ MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
+ descriptor.captureObject = ctx->capture_scope;
+ descriptor.destination = MTLCaptureDestinationGPUTraceDocument;
+ descriptor.outputURL = [NSURL fileURLWithPath:[NSString stringWithFormat:@"/tmp/perf-metal.gputrace"]];
+
+ NSError * error = nil;
+ if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
+ printf("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);
+ } else {
+ [ctx->capture_scope beginScope];
+ ctx->capture_started = true;
+ }
+ }
}
- id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? command_buffers[i + 1] : nil);
- if (!next_buffer) {
- continue;
+ // the main thread commits the first few commands immediately
+ // command_buffer[n_cb]
+ {
+ id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
+ ctx->command_buffers[n_cb] = command_buffer;
+
+ [command_buffer enqueue];
+ ctx->encode_async(n_cb);
}
- bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
- if (next_queued) {
- continue;
+ // prepare the rest of the command buffers asynchronously
+ // command_buffer[0.. n_cb)
+ for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
+ id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
+ ctx->command_buffers[cb_idx] = command_buffer;
+
+ // always enqueue the first two command buffers
+ // enqueue all of the command buffers if we don't need to abort
+ if (cb_idx < 2 || ctx->abort_callback == NULL) {
+ [command_buffer enqueue];
+ }
}
- if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
- GGML_METAL_LOG_INFO("%s: command buffer %d aborted", __func__, i);
- return GGML_STATUS_ABORTED;
+ dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async);
+
+ // wait for completion and check status of each command buffer
+ // needed to detect if the device ran out-of-memory for example (#1881)
+ {
+ id<MTLCommandBuffer> command_buffer = ctx->command_buffers[n_cb];
+ [command_buffer waitUntilCompleted];
+
+ MTLCommandBufferStatus status = [command_buffer status];
+ if (status != MTLCommandBufferStatusCompleted) {
+ printf("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
+ if (status == MTLCommandBufferStatusError) {
+ printf("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
+ }
+
+ return GGML_STATUS_FAILED;
+ }
}
+ for (int i = 0; i < n_cb; ++i) {
+ id<MTLCommandBuffer> command_buffer = ctx->command_buffers[i];
+ [command_buffer waitUntilCompleted];
+
+ MTLCommandBufferStatus status = [command_buffer status];
+ if (status != MTLCommandBufferStatusCompleted) {
+ printf("%s: command buffer %d failed with status %lu\n", __func__, i, status);
+ if (status == MTLCommandBufferStatusError) {
+ printf("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
+ }
- [next_buffer commit];
- }
+ return GGML_STATUS_FAILED;
+ }
- if (should_capture) {
- [[MTLCaptureManager sharedCaptureManager] stopCapture];
- }
+ id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->command_buffers[i + 1] : nil);
+ if (!next_buffer) {
+ continue;
+ }
+
+ const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
+ if (next_queued) {
+ continue;
+ }
+
+ if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
+ printf("%s: command buffer %d aborted", __func__, i);
+ return GGML_STATUS_ABORTED;
+ }
+
+ [next_buffer commit];
+ }
+ if (!should_capture && ctx->capture_started) {
+ [ctx->capture_scope endScope];
+ [[MTLCaptureManager sharedCaptureManager] stopCapture];
+ }
}
+
return GGML_STATUS_SUCCESS;
}
@@ -3905,8 +3947,60 @@ void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
GGML_ASSERT(ggml_backend_is_metal(backend));
struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
+ if (ctx->n_cb != n_cb) {
+ ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_COMMAND_BUFFERS);
+
+ if (ctx->n_cb > 2) {
+ GGML_METAL_LOG_WARN("%s: n_cb = %d, using n_cb > 2 is not recommended and can degrade the performance in some cases\n", __func__, n_cb);
+ }
+ }
+
+ if (ctx->encode_async) {
+ Block_release(ctx->encode_async);
+ }
+
+ ctx->encode_async = Block_copy(^(size_t iter) {
+ const int cb_idx = iter;
+ const int n_cb_l = ctx->n_cb;
+
+ const int n_nodes_0 = ctx->n_nodes_0;
+ const int n_nodes_1 = ctx->n_nodes_1;
+
+ const int n_nodes_per_cb = ctx->n_nodes_per_cb;
+
+ id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
+ id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
+
+ int node_start = 0;
+ int node_end = n_nodes_0;
+
+ if (cb_idx < n_cb_l) {
+ node_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb);
+ node_end = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1));
+ }
+
+ const bool should_capture = ctx->capture_next_compute;
+
+ for (int idx = node_start; idx < node_end; ++idx) {
+ struct ggml_tensor * node = ctx->gf->nodes[idx];
+ if (should_capture) {
+ [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(node) encoding:NSUTF8StringEncoding]];
+ }
+
+ ggml_metal_encode_node(ctx, node, encoder);
+
+ if (should_capture) {
+ [encoder popDebugGroup];
+ }
+ }
+
+ [encoder endEncoding];
+
+ if (cb_idx < 2 || ctx->abort_callback == NULL) {
+ [command_buffer commit];
+ }
+ });
- ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
}
void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data) {
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index 89cd412a..a67ec336 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -1652,13 +1652,18 @@ void kernel_mul_mv_q8_0_f32_impl(
yl[i] = yb[i];
}
+ device const block_q8_0 * xr = x + ib;
+
for (int row = 0; row < nr; row++) {
- device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il;
+ //device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il;
+ device const int8_t * qs = xr->qs + NB_Q8_0*il;
float sumq = 0.f;
for (int iq = 0; iq < NB_Q8_0; ++iq) {
sumq += qs[iq] * yl[iq];
}
- sumf[row] += sumq*x[ib+row*nb].d;
+ //sumf[row] += sumq*x[ib+row*nb].d;
+ sumf[row] += sumq*xr->d;
+ xr += nb;
}
yb += NB_Q8_0 * nw;
@@ -7218,10 +7223,16 @@ void dequantize_q6_0(device const block_q6_0 *xb, short il, thread type4x4 & reg
template <typename type4x4>
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
device const int8_t * qs = ((device const int8_t *)xb->qs);
- const half d = xb->d;
-
- for (int i = 0; i < 16; i++) {
- reg[i/4][i%4] = (qs[i + 16*il] * d);
+ if constexpr (is_same_v<type4x4, half4x4>) {
+ const half d = xb->d;
+ for (int i = 0; i < 16; i++) {
+ reg[i/4][i%4] = (half)qs[i + 16*il] * d;
+ }
+ } else {
+ const float d = xb->d;
+ for (int i = 0; i < 16; i++) {
+ reg[i/4][i%4] = qs[i + 16*il] * d;
+ }
}
}
@@ -8246,39 +8257,6 @@ kernel void kernel_mul_mm_id(
uint ntg = ntg3.x * ntg3.y * ntg3.z;
uint n = nei0*nei1;
- //uint npt = (n + ntg - 1) / ntg;
- //uint first = tiitg * npt;
- //uint last = first + npt <= n ? first + npt : n;
-
- //uint nhave = 0;
- //for (uint i = first; i < last; ++i) {
- // uint ii0 = i % nei0;
- // uint ii1 = i / nei0;
- // int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
- // if (id == i02) ++nhave;
- //}
- //threadgroup uint * nums = (threadgroup uint *)shared_memory;
- //nums[tiitg] = nhave;
- //threadgroup_barrier(mem_flags::mem_threadgroup);
-
- //uint nprev = 0;
- //for (uint i = 0; i < tiitg; ++i) nprev += nums[i];
- //int64_t _ne1 = nprev;
- //for (uint i = tiitg; i < ntg; ++i) _ne1 += nums[i];
-
- //threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
- //for (uint i = first; i < last; ++i) {
- // uint ii0 = i % nei0;
- // uint ii1 = i / nei0;
- // int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
- // if (id == i02) rowids[nprev++] = ushort2(ii0, ii1);
- //}
-
- //threadgroup_barrier(mem_flags::mem_threadgroup);
-
- //
- // The following is slightly faster than the commented out version above
- //
uint nhave = 0;
for (uint i = tiitg; i < n; i += ntg) {
uint ii0 = i % nei0;
@@ -8290,10 +8268,24 @@ kernel void kernel_mul_mm_id(
nums[tiitg] = nhave;
threadgroup_barrier(mem_flags::mem_threadgroup);
- uint nprev = 0;
- for (uint i = 0; i < tiitg; ++i) nprev += nums[i];
- int64_t _ne1 = nprev;
- for (uint i = tiitg; i < ntg; ++i) _ne1 += nums[i];
+ uint stride = 1;
+ while (stride <= ntg/2) {
+ uint index = (tiitg+1)*stride*2 - 1; // index - stride = 2*tiitg*stride + stride - 1;
+ if (index < ntg) nums[index] += nums[index-stride];
+ stride <<= 1;
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+ stride = ntg/2;
+ while (stride > 0) {
+ uint index = (tiitg+1)*stride*2 - 1;
+ if (index+stride < ntg) nums[index+stride] += nums[index];
+ stride >>= 1;
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+ uint _ne1 = nums[ntg-1];
+ if (!_ne1) return;
+
+ uint nprev = tiitg > 0 ? nums[tiitg-1] : 0;
threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
for (uint i = tiitg; i < n; i += ntg) {
@@ -8304,47 +8296,37 @@ kernel void kernel_mul_mm_id(
}
threadgroup_barrier(mem_flags::mem_threadgroup);
- // This is the original version that is ridiculously slow.
- //// row indices
- //threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
-
- //// TODO: parallelize this loop
- //int64_t _ne1 = 0;
- //for (ushort ii1 = 0; ii1 < nei1; ii1++) {
- // for (ushort ii0 = 0; ii0 < nei0; ii0++) {
- // int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
- // if (id == i02) {
- // //if (tiitg == 0) {
- // rowids[_ne1] = ushort2(ii0, ii1);
- // //}
- // _ne1++;
- // }
- // }
- //}
-
- //threadgroup_barrier(mem_flags::mem_threadgroup);
-
- kernel_mul_mm_id_impl<Dequantizer>(
- src0,
- src1,
- rowids,
- dst,
- ne00,
- ne02,
- nb01,
- nb02,
- ne11,
- ne12,
- nb10,
- nb11,
- nb12,
- ne0,
- _ne1,
- ne0*ne1,
- shared_memory,
- tgpig,
- tiitg,
- sgitg);
+ uint nstep = (_ne1 + BLOCK_SIZE_N - 1)/BLOCK_SIZE_N;
+
+ for (uint istep = 0; istep < nstep; ++istep) {
+
+ uint first = BLOCK_SIZE_N*istep;
+ uint last = first + BLOCK_SIZE_N < _ne1 ? first + BLOCK_SIZE_N : _ne1;
+ int64_t this_ne1 = last - first;
+ threadgroup ushort2 * this_rowids = rowids + istep*BLOCK_SIZE_N;
+
+ kernel_mul_mm_id_impl<Dequantizer>(
+ src0,
+ src1,
+ this_rowids,
+ dst,
+ ne00,
+ ne02,
+ nb01,
+ nb02,
+ ne11,
+ ne12,
+ nb10,
+ nb11,
+ nb12,
+ ne0,
+ this_ne1,
+ ne0*ne1,
+ shared_memory,
+ tgpig,
+ tiitg,
+ sgitg);
+ }
}
#define QK_NL 16