summaryrefslogtreecommitdiff
path: root/ggml-metal.m
diff options
context:
space:
mode:
Diffstat (limited to 'ggml-metal.m')
-rw-r--r--ggml-metal.m35
1 files changed, 27 insertions, 8 deletions
diff --git a/ggml-metal.m b/ggml-metal.m
index 6e76f8be..c0848a29 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -728,6 +728,7 @@ static bool ggml_metal_graph_compute(
size_t offs_src0 = 0;
size_t offs_src1 = 0;
+ size_t offs_src2 = 0;
size_t offs_dst = 0;
id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
@@ -746,6 +747,7 @@ static bool ggml_metal_graph_compute(
struct ggml_tensor * src0 = gf->nodes[i]->src[0];
struct ggml_tensor * src1 = gf->nodes[i]->src[1];
+ struct ggml_tensor * src2 = gf->nodes[i]->src[2];
struct ggml_tensor * dst = gf->nodes[i];
switch (dst->op) {
@@ -807,6 +809,7 @@ static bool ggml_metal_graph_compute(
id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
+ id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
//GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
@@ -1188,7 +1191,16 @@ static bool ggml_metal_graph_compute(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
}
- const float scale = ((float *) dst->op_params)[0];
+ const float scale = ((float *) dst->op_params)[0];
+ const float max_bias = ((float *) dst->op_params)[1];
+
+ const int64_t nrows_x = ggml_nrows(src0);
+ const int64_t nrows_y = src0->ne[1];
+ const uint32_t n_head_kv = nrows_x/nrows_y;
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1197,11 +1209,20 @@ static bool ggml_metal_graph_compute(
} else {
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
}
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
- [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
+ if (id_src2) {
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+ } else {
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
+ }
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:4];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:5];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:7];
+ [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:8];
+ [encoder setBytes:&m0 length:sizeof(m0) atIndex:9];
+ [encoder setBytes:&m1 length:sizeof(m1) atIndex:10];
+ [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11];
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
@@ -1514,8 +1535,6 @@ static bool ggml_metal_graph_compute(
// max size of the src1ids array in the kernel stack
GGML_ASSERT(ne11 <= 512);
- struct ggml_tensor * src2 = gf->nodes[i]->src[2];
-
const int64_t ne20 = src2 ? src2->ne[0] : 0;
const int64_t ne21 = src2 ? src2->ne[1] : 0;
const int64_t ne22 = src2 ? src2->ne[2] : 0;