From 7f968d51b4eb6f403bb7dbc1a5bbf98491ff293b Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Wed, 19 Jun 2024 18:23:57 +0200
Subject: bitnet(scale in a separate tensor): mul -> scale on Metal

Do the mul -> scale replacement on the fly in the Metal backend.
This recovers the prompt processing (PP) performance and cuts the
token generation (TG) performance degradation in half.
---
 ggml-metal.m | 25 ++++++++++++++++++++++++-
 1 file changed, 24 insertions(+), 1 deletion(-)

diff --git a/ggml-metal.m b/ggml-metal.m
index 9911f524..d6f2df94 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -1077,7 +1077,30 @@ static enum ggml_status ggml_metal_graph_compute(
 
                     id<MTLComputePipelineState> pipeline = nil;
 
-                    if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
+                    if (dst->op == GGML_OP_MUL && ggml_nelements(src1) == 1 && ggml_is_contiguous(src0)) {
+                        float scale;
+                        memcpy(&scale, src1->data, sizeof(float));
+                        //printf("Replacing op_mul with op_scale. scale = %g\n", (double)scale);
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline;
+
+                        int64_t n = ggml_nelements(dst);
+
+                        if (n % 4 == 0) {
+                            n /= 4;
+                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline;
+                        } else {
+                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline;
+                        }
+
+                        [encoder setComputePipelineState:pipeline];
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                        [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
+
+                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        break;
+                    }
+                    else if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
                         GGML_ASSERT(ggml_is_contiguous(src0));
 
                         // src1 is a row
--
cgit v1.2.3
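
Editor's note: below is a minimal standalone C sketch of the idea behind this patch. It is not ggml code: tensor_t, nelements, and mul_or_scale are simplified stand-ins invented for illustration, and the real patch makes the substitution at kernel-dispatch time in the Metal backend rather than on the CPU. The sketch only shows the decision the patch encodes: when the second operand of a multiply is a single float (the separate per-tensor scale), the operation can be executed as a scale by that constant instead of a full element-wise multiply.

/* Standalone sketch (hypothetical names, not the ggml API):
 * lower an element-wise MUL to a SCALE when src1 is a lone scalar. */
#include <stdio.h>
#include <string.h>
#include <stdint.h>

typedef struct {
    int64_t nelements;   /* total number of elements          */
    float  *data;        /* contiguous buffer (assumed)       */
} tensor_t;

/* dst = src0 * src1, taking the SCALE shortcut when src1 has one element */
static void mul_or_scale(tensor_t *dst, const tensor_t *src0, const tensor_t *src1) {
    if (src1->nelements == 1) {
        float scale;
        memcpy(&scale, src1->data, sizeof(float));       /* same trick as the patch */
        for (int64_t i = 0; i < dst->nelements; ++i) {
            dst->data[i] = src0->data[i] * scale;         /* SCALE path */
        }
    } else {
        for (int64_t i = 0; i < dst->nelements; ++i) {    /* generic broadcast MUL path */
            dst->data[i] = src0->data[i] * src1->data[i % src1->nelements];
        }
    }
}

int main(void) {
    float a[4] = {1, 2, 3, 4}, s = 0.5f, out[4];
    tensor_t src0 = {4, a}, src1 = {1, &s}, dst = {4, out};
    mul_or_scale(&dst, &src0, &src1);
    for (int i = 0; i < 4; ++i) printf("%g ", out[i]);   /* expected: 0.5 1 1.5 2 */
    printf("\n");
    return 0;
}

The SCALE vs. SCALE_4 choice in the patch is the usual vector-width selection: when the element count is divisible by 4, the float4 kernel variant is used and the number of dispatched threadgroups is divided by 4, so each thread handles four elements instead of one.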