summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorIwan Kawrakow <iwan.kawrakow@gmail.com>2024-06-19 18:23:57 +0200
committerIwan Kawrakow <iwan.kawrakow@gmail.com>2024-06-22 12:02:52 +0300
commit7f968d51b4eb6f403bb7dbc1a5bbf98491ff293b (patch)
tree1f6c25e0449aa4684f8b30be5b05b04611c42b68
parentd08ff0df433ee9dd8643afe1cf501c4154067cd2 (diff)
bitnet(scale in a separate tensor): mul -> scale on Metal
Do the mul -> scale replacement on the fly in the Metal backend. This recovers the PP performance and cuts the TG performance degradation in half.
-rw-r--r--ggml-metal.m25
1 file changed, 24 insertions, 1 deletion
diff --git a/ggml-metal.m b/ggml-metal.m
index 9911f524..d6f2df94 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -1077,7 +1077,30 @@ static enum ggml_status ggml_metal_graph_compute(
id<MTLComputePipelineState> pipeline = nil;
- if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
+ if (dst->op == GGML_OP_MUL && ggml_nelements(src1) == 1 && ggml_is_contiguous(src0)) {
+ float scale;
+ memcpy(&scale, src1->data, sizeof(float));
+ //printf("Replacing op_mul with op_scale. scale = %g\n", (double)scale);
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline;
+
+ int64_t n = ggml_nelements(dst);
+
+ if (n % 4 == 0) {
+ n /= 4;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline;
+ }
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ break;
+ }
+ else if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
GGML_ASSERT(ggml_is_contiguous(src0));
// src1 is a row